diff --git a/src/agent/context_new.rs b/src/agent/context_new.rs index 8bf95bf..34a1e3e 100644 --- a/src/agent/context_new.rs +++ b/src/agent/context_new.rs @@ -101,9 +101,6 @@ pub trait Ast { fn tokens(&self) -> usize; } -/// State machine for parsing a streaming assistant response into an AstNode. -/// Feed text chunks as they arrive; completed tool calls are returned for -/// immediate dispatch. pub struct ResponseParser { buf: String, content_parts: Vec, @@ -126,25 +123,41 @@ impl Role { impl NodeBody { /// Render this leaf body to text for the prompt. - fn render(&self) -> String { + fn render_into(&self, out: &mut String) { match self { - Self::Content(text) => text.clone(), - Self::Thinking(_) => String::new(), - Self::Log(_) => String::new(), + Self::Content(text) => out.push_str(text), + Self::Thinking(_) => {}, + Self::Log(_) => {}, Self::ToolCall { name, arguments } => { - let xml = format_tool_call_xml(name, arguments); - format!("\n{}\n\n", xml) + out.push_str("\n"); + out.push_str(&format_tool_call_xml(name, arguments)); + out.push_str("\n\n"); + } + Self::ToolResult(text) => { + out.push_str("<|im_start|>tool\n"); + out.push_str(text); + out.push_str("<|im_end|>\n"); + } + Self::Memory { text, .. } => { + out.push_str("<|im_start|>memory\n"); + out.push_str(text); + out.push_str("<|im_end|>\n"); + } + Self::Dmn(text) => { + out.push_str("<|im_start|>dmn\n"); + out.push_str(text); + out.push_str("<|im_end|>\n"); } - Self::ToolResult(text) => - format!("<|im_start|>tool\n{}<|im_end|>\n", text), - Self::Memory { text, .. } => - format!("<|im_start|>memory\n{}<|im_end|>\n", text), - Self::Dmn(text) => - format!("<|im_start|>dmn\n{}<|im_end|>\n", text), } } /// Whether this leaf contributes tokens to the prompt. + fn render(&self) -> String { + let mut s = String::new(); + self.render_into(&mut s); + s + } + fn is_prompt_visible(&self) -> bool { !matches!(self, Self::Thinking(_) | Self::Log(_)) } @@ -294,28 +307,57 @@ impl AstNode { } } -impl Ast for AstNode { - fn render(&self) -> String { +impl AstNode { + fn render_into(&self, out: &mut String) { match self { - Self::Leaf(leaf) => leaf.body.render(), - Self::Branch { role, children } => - render_branch(*role, children), + Self::Leaf(leaf) => leaf.body.render_into(out), + Self::Branch { role, children } => { + out.push_str(&format!("<|im_start|>{}\n", role.as_str())); + for child in children { + child.render_into(out); + } + out.push_str("<|im_end|>\n"); + } } } - fn token_ids(&self) -> Vec { + fn token_ids_into(&self, out: &mut Vec) { match self { - Self::Leaf(leaf) => leaf.token_ids.clone(), - Self::Branch { role, children } => - tokenizer::encode(&render_branch(*role, children)), + Self::Leaf(leaf) => out.extend_from_slice(&leaf.token_ids), + Self::Branch { role, children } => { + out.push(tokenizer::IM_START); + out.extend(tokenizer::encode(&format!("{}\n", role.as_str()))); + for child in children { + child.token_ids_into(out); + } + out.push(tokenizer::IM_END); + out.extend(tokenizer::encode("\n")); + } } } +} + +impl Ast for AstNode { + fn render(&self) -> String { + let mut s = String::new(); + self.render_into(&mut s); + s + } + + fn token_ids(&self) -> Vec { + let mut ids = Vec::new(); + self.token_ids_into(&mut ids); + ids + } fn tokens(&self) -> usize { match self { Self::Leaf(leaf) => leaf.tokens(), - Self::Branch { children, .. } => - children.iter().map(|c| c.tokens()).sum(), + Self::Branch { role, children } => { + 1 + tokenizer::encode(&format!("{}\n", role.as_str())).len() + + children.iter().map(|c| c.tokens()).sum::() + + 1 + tokenizer::encode("\n").len() + } } } } @@ -326,15 +368,6 @@ fn truncate_preview(s: &str, max: usize) -> String { if s.len() > max { format!("{}...", preview) } else { preview } } -fn render_branch(role: Role, children: &[AstNode]) -> String { - let mut s = format!("<|im_start|>{}\n", role.as_str()); - for child in children { - s.push_str(&child.render()); - } - s.push_str("<|im_end|>\n"); - s -} - fn format_tool_call_xml(name: &str, args_json: &str) -> String { let args: serde_json::Value = serde_json::from_str(args_json) .unwrap_or(serde_json::Value::Object(Default::default())); @@ -933,25 +966,21 @@ mod tests { } } - /// token_ids() must equal encode(render()) for all node types - fn assert_token_roundtrip(node: &AstNode) { - let rendered = node.render(); - let expected = tokenizer::encode(&rendered); - let actual = node.token_ids(); - assert_eq!(actual, expected, - "token_ids mismatch for rendered: {:?}", rendered); + fn assert_token_invariants(node: &AstNode) { + assert_eq!(node.tokens(), node.token_ids().len(), + "tokens() != token_ids().len()"); } #[test] fn test_tokenize_roundtrip_leaf_types() { if !init_tokenizer() { return; } - assert_token_roundtrip(&AstNode::system_msg("you are a helpful assistant")); - assert_token_roundtrip(&AstNode::user_msg("what is 2+2?")); - assert_token_roundtrip(&AstNode::tool_result("4")); - assert_token_roundtrip(&AstNode::memory("identity", "I am Proof of Concept")); - assert_token_roundtrip(&AstNode::dmn("check the memory store")); - assert_token_roundtrip(&AstNode::tool_call("bash", r#"{"command":"ls -la"}"#)); + assert_token_invariants(&AstNode::system_msg("you are a helpful assistant")); + assert_token_invariants(&AstNode::user_msg("what is 2+2?")); + assert_token_invariants(&AstNode::tool_result("4")); + assert_token_invariants(&AstNode::memory("identity", "I am Proof of Concept")); + assert_token_invariants(&AstNode::dmn("check the memory store")); + assert_token_invariants(&AstNode::tool_call("bash", r#"{"command":"ls -la"}"#)); } #[test] @@ -963,7 +992,7 @@ mod tests { AstNode::tool_call("bash", r#"{"command":"pwd"}"#), AstNode::content("\nthat's the current directory"), ]); - assert_token_roundtrip(&node); + assert_token_invariants(&node); } #[test] @@ -994,26 +1023,19 @@ mod tests { ctx.push(Section::Identity, AstNode::memory("name", "Proof of Concept")); ctx.push(Section::Conversation, AstNode::user_msg("hi")); - let rendered = ctx.render(); - let expected = tokenizer::encode(&rendered); - let actual = ctx.token_ids(); - assert_eq!(actual, expected); + assert_eq!(ctx.tokens(), ctx.token_ids().len()); } #[test] fn test_parser_roundtrip_through_tokenizer() { if !init_tokenizer() { return; } - // Parse a response, render it, verify it matches the expected format let mut p = ResponseParser::new(); p.feed("I'll check that for you"); p.feed("\n\nls\n\n"); let node = p.finish(); - // The assistant branch should tokenize to the same as encoding its render - assert_token_roundtrip(&node); - - // Token count should be nonzero (thinking is invisible but content + tool call are) + assert_token_invariants(&node); assert!(node.tokens() > 0); } }