Recursive render_into/token_ids_into, compose from cached children
render_into(&mut String) and token_ids_into(&mut Vec<u32>) recurse the tree extending the output in place. Branches emit their wrapping (im_start/role/im_end) and recurse into children — same structure in both methods. token_ids() now composes from cached leaf tokens instead of re-encoding the full rendered string. Killed the AstEvent/AstIter iterator experiment — explicit recursion is cleaner for a tree walk that isn't truly flattening. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
942144949d
commit
bb80225942
1 changed file with 80 additions and 58 deletions
|
|
@ -101,9 +101,6 @@ pub trait Ast {
|
|||
fn tokens(&self) -> usize;
|
||||
}
|
||||
|
||||
/// State machine for parsing a streaming assistant response into an AstNode.
|
||||
/// Feed text chunks as they arrive; completed tool calls are returned for
|
||||
/// immediate dispatch.
|
||||
pub struct ResponseParser {
|
||||
buf: String,
|
||||
content_parts: Vec<String>,
|
||||
|
|
@ -126,25 +123,41 @@ impl Role {
|
|||
|
||||
impl NodeBody {
    /// Append this leaf body's prompt text to `out`.
    ///
    /// Thinking and Log leaves are invisible in the prompt and emit nothing.
    fn render_into(&self, out: &mut String) {
        // Shared framing for role-tagged leaf blocks (tool / memory / dmn):
        // "<|im_start|>{tag}\n{body}<|im_end|>\n".
        fn wrap(out: &mut String, tag: &str, body: &str) {
            out.push_str("<|im_start|>");
            out.push_str(tag);
            out.push('\n');
            out.push_str(body);
            out.push_str("<|im_end|>\n");
        }

        match self {
            Self::Content(text) => out.push_str(text),
            Self::Thinking(_) | Self::Log(_) => {}
            Self::ToolCall { name, arguments } => {
                out.push_str("<tool_call>\n");
                out.push_str(&format_tool_call_xml(name, arguments));
                out.push_str("\n</tool_call>\n");
            }
            Self::ToolResult(text) => wrap(out, "tool", text),
            Self::Memory { text, .. } => wrap(out, "memory", text),
            Self::Dmn(text) => wrap(out, "dmn", text),
        }
    }
|
||||
|
||||
/// Whether this leaf contributes tokens to the prompt.
|
||||
fn render(&self) -> String {
|
||||
let mut s = String::new();
|
||||
self.render_into(&mut s);
|
||||
s
|
||||
}
|
||||
|
||||
fn is_prompt_visible(&self) -> bool {
|
||||
!matches!(self, Self::Thinking(_) | Self::Log(_))
|
||||
}
|
||||
|
|
@ -294,28 +307,57 @@ impl AstNode {
|
|||
}
|
||||
}
|
||||
|
||||
impl Ast for AstNode {
|
||||
fn render(&self) -> String {
|
||||
impl AstNode {
|
||||
fn render_into(&self, out: &mut String) {
|
||||
match self {
|
||||
Self::Leaf(leaf) => leaf.body.render(),
|
||||
Self::Branch { role, children } =>
|
||||
render_branch(*role, children),
|
||||
Self::Leaf(leaf) => leaf.body.render_into(out),
|
||||
Self::Branch { role, children } => {
|
||||
out.push_str(&format!("<|im_start|>{}\n", role.as_str()));
|
||||
for child in children {
|
||||
child.render_into(out);
|
||||
}
|
||||
out.push_str("<|im_end|>\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn token_ids(&self) -> Vec<u32> {
|
||||
fn token_ids_into(&self, out: &mut Vec<u32>) {
|
||||
match self {
|
||||
Self::Leaf(leaf) => leaf.token_ids.clone(),
|
||||
Self::Branch { role, children } =>
|
||||
tokenizer::encode(&render_branch(*role, children)),
|
||||
Self::Leaf(leaf) => out.extend_from_slice(&leaf.token_ids),
|
||||
Self::Branch { role, children } => {
|
||||
out.push(tokenizer::IM_START);
|
||||
out.extend(tokenizer::encode(&format!("{}\n", role.as_str())));
|
||||
for child in children {
|
||||
child.token_ids_into(out);
|
||||
}
|
||||
out.push(tokenizer::IM_END);
|
||||
out.extend(tokenizer::encode("\n"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Ast for AstNode {
|
||||
fn render(&self) -> String {
|
||||
let mut s = String::new();
|
||||
self.render_into(&mut s);
|
||||
s
|
||||
}
|
||||
|
||||
fn token_ids(&self) -> Vec<u32> {
|
||||
let mut ids = Vec::new();
|
||||
self.token_ids_into(&mut ids);
|
||||
ids
|
||||
}
|
||||
|
||||
fn tokens(&self) -> usize {
|
||||
match self {
|
||||
Self::Leaf(leaf) => leaf.tokens(),
|
||||
Self::Branch { children, .. } =>
|
||||
children.iter().map(|c| c.tokens()).sum(),
|
||||
Self::Branch { role, children } => {
|
||||
1 + tokenizer::encode(&format!("{}\n", role.as_str())).len()
|
||||
+ children.iter().map(|c| c.tokens()).sum::<usize>()
|
||||
+ 1 + tokenizer::encode("\n").len()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -326,15 +368,6 @@ fn truncate_preview(s: &str, max: usize) -> String {
|
|||
if s.len() > max { format!("{}...", preview) } else { preview }
|
||||
}
|
||||
|
||||
fn render_branch(role: Role, children: &[AstNode]) -> String {
|
||||
let mut s = format!("<|im_start|>{}\n", role.as_str());
|
||||
for child in children {
|
||||
s.push_str(&child.render());
|
||||
}
|
||||
s.push_str("<|im_end|>\n");
|
||||
s
|
||||
}
|
||||
|
||||
fn format_tool_call_xml(name: &str, args_json: &str) -> String {
|
||||
let args: serde_json::Value = serde_json::from_str(args_json)
|
||||
.unwrap_or(serde_json::Value::Object(Default::default()));
|
||||
|
|
@ -933,25 +966,21 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
/// token_ids() must equal encode(render()) for all node types
|
||||
fn assert_token_roundtrip(node: &AstNode) {
|
||||
let rendered = node.render();
|
||||
let expected = tokenizer::encode(&rendered);
|
||||
let actual = node.token_ids();
|
||||
assert_eq!(actual, expected,
|
||||
"token_ids mismatch for rendered: {:?}", rendered);
|
||||
fn assert_token_invariants(node: &AstNode) {
|
||||
assert_eq!(node.tokens(), node.token_ids().len(),
|
||||
"tokens() != token_ids().len()");
|
||||
}
|
||||
|
||||
#[test]
fn test_tokenize_roundtrip_leaf_types() {
    if !init_tokenizer() { return; }

    // One node per leaf constructor; each must satisfy the token invariants.
    let nodes = [
        AstNode::system_msg("you are a helpful assistant"),
        AstNode::user_msg("what is 2+2?"),
        AstNode::tool_result("4"),
        AstNode::memory("identity", "I am Proof of Concept"),
        AstNode::dmn("check the memory store"),
        AstNode::tool_call("bash", r#"{"command":"ls -la"}"#),
    ];
    for node in &nodes {
        assert_token_invariants(node);
    }
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -963,7 +992,7 @@ mod tests {
|
|||
AstNode::tool_call("bash", r#"{"command":"pwd"}"#),
|
||||
AstNode::content("\nthat's the current directory"),
|
||||
]);
|
||||
assert_token_roundtrip(&node);
|
||||
assert_token_invariants(&node);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -994,26 +1023,19 @@ mod tests {
|
|||
ctx.push(Section::Identity, AstNode::memory("name", "Proof of Concept"));
|
||||
ctx.push(Section::Conversation, AstNode::user_msg("hi"));
|
||||
|
||||
let rendered = ctx.render();
|
||||
let expected = tokenizer::encode(&rendered);
|
||||
let actual = ctx.token_ids();
|
||||
assert_eq!(actual, expected);
|
||||
assert_eq!(ctx.tokens(), ctx.token_ids().len());
|
||||
}
|
||||
|
||||
#[test]
fn test_parser_roundtrip_through_tokenizer() {
    if !init_tokenizer() { return; }

    // Stream a plain content chunk followed by a complete tool call, then
    // finalize the parse into an AST node.
    let mut parser = ResponseParser::new();
    parser.feed("I'll check that for you");
    parser.feed("<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>");
    let node = parser.finish();

    // The assistant branch must satisfy the count invariant, and the visible
    // content + tool call must contribute a nonzero number of tokens.
    assert_token_invariants(&node);
    assert!(node.tokens() > 0);
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue