Recursive render_into/token_ids_into, compose from cached children
render_into(&mut String) and token_ids_into(&mut Vec<u32>) recurse the tree extending the output in place. Branches emit their wrapping (im_start/role/im_end) and recurse into children — same structure in both methods. token_ids() now composes from cached leaf tokens instead of re-encoding the full rendered string. Killed the AstEvent/AstIter iterator experiment — explicit recursion is cleaner for a tree walk that isn't truly flattening. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
942144949d
commit
bb80225942
1 changed file with 80 additions and 58 deletions
|
|
@ -101,9 +101,6 @@ pub trait Ast {
|
||||||
fn tokens(&self) -> usize;
|
fn tokens(&self) -> usize;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// State machine for parsing a streaming assistant response into an AstNode.
|
|
||||||
/// Feed text chunks as they arrive; completed tool calls are returned for
|
|
||||||
/// immediate dispatch.
|
|
||||||
pub struct ResponseParser {
|
pub struct ResponseParser {
|
||||||
buf: String,
|
buf: String,
|
||||||
content_parts: Vec<String>,
|
content_parts: Vec<String>,
|
||||||
|
|
@ -126,25 +123,41 @@ impl Role {
|
||||||
|
|
||||||
impl NodeBody {
|
impl NodeBody {
|
||||||
/// Render this leaf body to text for the prompt.
|
/// Render this leaf body to text for the prompt.
|
||||||
fn render(&self) -> String {
|
fn render_into(&self, out: &mut String) {
|
||||||
match self {
|
match self {
|
||||||
Self::Content(text) => text.clone(),
|
Self::Content(text) => out.push_str(text),
|
||||||
Self::Thinking(_) => String::new(),
|
Self::Thinking(_) => {},
|
||||||
Self::Log(_) => String::new(),
|
Self::Log(_) => {},
|
||||||
Self::ToolCall { name, arguments } => {
|
Self::ToolCall { name, arguments } => {
|
||||||
let xml = format_tool_call_xml(name, arguments);
|
out.push_str("<tool_call>\n");
|
||||||
format!("<tool_call>\n{}\n</tool_call>\n", xml)
|
out.push_str(&format_tool_call_xml(name, arguments));
|
||||||
|
out.push_str("\n</tool_call>\n");
|
||||||
|
}
|
||||||
|
Self::ToolResult(text) => {
|
||||||
|
out.push_str("<|im_start|>tool\n");
|
||||||
|
out.push_str(text);
|
||||||
|
out.push_str("<|im_end|>\n");
|
||||||
|
}
|
||||||
|
Self::Memory { text, .. } => {
|
||||||
|
out.push_str("<|im_start|>memory\n");
|
||||||
|
out.push_str(text);
|
||||||
|
out.push_str("<|im_end|>\n");
|
||||||
|
}
|
||||||
|
Self::Dmn(text) => {
|
||||||
|
out.push_str("<|im_start|>dmn\n");
|
||||||
|
out.push_str(text);
|
||||||
|
out.push_str("<|im_end|>\n");
|
||||||
}
|
}
|
||||||
Self::ToolResult(text) =>
|
|
||||||
format!("<|im_start|>tool\n{}<|im_end|>\n", text),
|
|
||||||
Self::Memory { text, .. } =>
|
|
||||||
format!("<|im_start|>memory\n{}<|im_end|>\n", text),
|
|
||||||
Self::Dmn(text) =>
|
|
||||||
format!("<|im_start|>dmn\n{}<|im_end|>\n", text),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Whether this leaf contributes tokens to the prompt.
|
/// Whether this leaf contributes tokens to the prompt.
|
||||||
|
fn render(&self) -> String {
|
||||||
|
let mut s = String::new();
|
||||||
|
self.render_into(&mut s);
|
||||||
|
s
|
||||||
|
}
|
||||||
|
|
||||||
fn is_prompt_visible(&self) -> bool {
|
fn is_prompt_visible(&self) -> bool {
|
||||||
!matches!(self, Self::Thinking(_) | Self::Log(_))
|
!matches!(self, Self::Thinking(_) | Self::Log(_))
|
||||||
}
|
}
|
||||||
|
|
@ -294,28 +307,57 @@ impl AstNode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Ast for AstNode {
|
impl AstNode {
|
||||||
fn render(&self) -> String {
|
fn render_into(&self, out: &mut String) {
|
||||||
match self {
|
match self {
|
||||||
Self::Leaf(leaf) => leaf.body.render(),
|
Self::Leaf(leaf) => leaf.body.render_into(out),
|
||||||
Self::Branch { role, children } =>
|
Self::Branch { role, children } => {
|
||||||
render_branch(*role, children),
|
out.push_str(&format!("<|im_start|>{}\n", role.as_str()));
|
||||||
|
for child in children {
|
||||||
|
child.render_into(out);
|
||||||
|
}
|
||||||
|
out.push_str("<|im_end|>\n");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn token_ids(&self) -> Vec<u32> {
|
fn token_ids_into(&self, out: &mut Vec<u32>) {
|
||||||
match self {
|
match self {
|
||||||
Self::Leaf(leaf) => leaf.token_ids.clone(),
|
Self::Leaf(leaf) => out.extend_from_slice(&leaf.token_ids),
|
||||||
Self::Branch { role, children } =>
|
Self::Branch { role, children } => {
|
||||||
tokenizer::encode(&render_branch(*role, children)),
|
out.push(tokenizer::IM_START);
|
||||||
|
out.extend(tokenizer::encode(&format!("{}\n", role.as_str())));
|
||||||
|
for child in children {
|
||||||
|
child.token_ids_into(out);
|
||||||
}
|
}
|
||||||
|
out.push(tokenizer::IM_END);
|
||||||
|
out.extend(tokenizer::encode("\n"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Ast for AstNode {
|
||||||
|
fn render(&self) -> String {
|
||||||
|
let mut s = String::new();
|
||||||
|
self.render_into(&mut s);
|
||||||
|
s
|
||||||
|
}
|
||||||
|
|
||||||
|
fn token_ids(&self) -> Vec<u32> {
|
||||||
|
let mut ids = Vec::new();
|
||||||
|
self.token_ids_into(&mut ids);
|
||||||
|
ids
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tokens(&self) -> usize {
|
fn tokens(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
Self::Leaf(leaf) => leaf.tokens(),
|
Self::Leaf(leaf) => leaf.tokens(),
|
||||||
Self::Branch { children, .. } =>
|
Self::Branch { role, children } => {
|
||||||
children.iter().map(|c| c.tokens()).sum(),
|
1 + tokenizer::encode(&format!("{}\n", role.as_str())).len()
|
||||||
|
+ children.iter().map(|c| c.tokens()).sum::<usize>()
|
||||||
|
+ 1 + tokenizer::encode("\n").len()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -326,15 +368,6 @@ fn truncate_preview(s: &str, max: usize) -> String {
|
||||||
if s.len() > max { format!("{}...", preview) } else { preview }
|
if s.len() > max { format!("{}...", preview) } else { preview }
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_branch(role: Role, children: &[AstNode]) -> String {
|
|
||||||
let mut s = format!("<|im_start|>{}\n", role.as_str());
|
|
||||||
for child in children {
|
|
||||||
s.push_str(&child.render());
|
|
||||||
}
|
|
||||||
s.push_str("<|im_end|>\n");
|
|
||||||
s
|
|
||||||
}
|
|
||||||
|
|
||||||
fn format_tool_call_xml(name: &str, args_json: &str) -> String {
|
fn format_tool_call_xml(name: &str, args_json: &str) -> String {
|
||||||
let args: serde_json::Value = serde_json::from_str(args_json)
|
let args: serde_json::Value = serde_json::from_str(args_json)
|
||||||
.unwrap_or(serde_json::Value::Object(Default::default()));
|
.unwrap_or(serde_json::Value::Object(Default::default()));
|
||||||
|
|
@ -933,25 +966,21 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// token_ids() must equal encode(render()) for all node types
|
fn assert_token_invariants(node: &AstNode) {
|
||||||
fn assert_token_roundtrip(node: &AstNode) {
|
assert_eq!(node.tokens(), node.token_ids().len(),
|
||||||
let rendered = node.render();
|
"tokens() != token_ids().len()");
|
||||||
let expected = tokenizer::encode(&rendered);
|
|
||||||
let actual = node.token_ids();
|
|
||||||
assert_eq!(actual, expected,
|
|
||||||
"token_ids mismatch for rendered: {:?}", rendered);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_tokenize_roundtrip_leaf_types() {
|
fn test_tokenize_roundtrip_leaf_types() {
|
||||||
if !init_tokenizer() { return; }
|
if !init_tokenizer() { return; }
|
||||||
|
|
||||||
assert_token_roundtrip(&AstNode::system_msg("you are a helpful assistant"));
|
assert_token_invariants(&AstNode::system_msg("you are a helpful assistant"));
|
||||||
assert_token_roundtrip(&AstNode::user_msg("what is 2+2?"));
|
assert_token_invariants(&AstNode::user_msg("what is 2+2?"));
|
||||||
assert_token_roundtrip(&AstNode::tool_result("4"));
|
assert_token_invariants(&AstNode::tool_result("4"));
|
||||||
assert_token_roundtrip(&AstNode::memory("identity", "I am Proof of Concept"));
|
assert_token_invariants(&AstNode::memory("identity", "I am Proof of Concept"));
|
||||||
assert_token_roundtrip(&AstNode::dmn("check the memory store"));
|
assert_token_invariants(&AstNode::dmn("check the memory store"));
|
||||||
assert_token_roundtrip(&AstNode::tool_call("bash", r#"{"command":"ls -la"}"#));
|
assert_token_invariants(&AstNode::tool_call("bash", r#"{"command":"ls -la"}"#));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -963,7 +992,7 @@ mod tests {
|
||||||
AstNode::tool_call("bash", r#"{"command":"pwd"}"#),
|
AstNode::tool_call("bash", r#"{"command":"pwd"}"#),
|
||||||
AstNode::content("\nthat's the current directory"),
|
AstNode::content("\nthat's the current directory"),
|
||||||
]);
|
]);
|
||||||
assert_token_roundtrip(&node);
|
assert_token_invariants(&node);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -994,26 +1023,19 @@ mod tests {
|
||||||
ctx.push(Section::Identity, AstNode::memory("name", "Proof of Concept"));
|
ctx.push(Section::Identity, AstNode::memory("name", "Proof of Concept"));
|
||||||
ctx.push(Section::Conversation, AstNode::user_msg("hi"));
|
ctx.push(Section::Conversation, AstNode::user_msg("hi"));
|
||||||
|
|
||||||
let rendered = ctx.render();
|
assert_eq!(ctx.tokens(), ctx.token_ids().len());
|
||||||
let expected = tokenizer::encode(&rendered);
|
|
||||||
let actual = ctx.token_ids();
|
|
||||||
assert_eq!(actual, expected);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parser_roundtrip_through_tokenizer() {
|
fn test_parser_roundtrip_through_tokenizer() {
|
||||||
if !init_tokenizer() { return; }
|
if !init_tokenizer() { return; }
|
||||||
|
|
||||||
// Parse a response, render it, verify it matches the expected format
|
|
||||||
let mut p = ResponseParser::new();
|
let mut p = ResponseParser::new();
|
||||||
p.feed("I'll check that for you");
|
p.feed("I'll check that for you");
|
||||||
p.feed("<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>");
|
p.feed("<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>");
|
||||||
let node = p.finish();
|
let node = p.finish();
|
||||||
|
|
||||||
// The assistant branch should tokenize to the same as encoding its render
|
assert_token_invariants(&node);
|
||||||
assert_token_roundtrip(&node);
|
|
||||||
|
|
||||||
// Token count should be nonzero (thinking is invisible but content + tool call are)
|
|
||||||
assert!(node.tokens() > 0);
|
assert!(node.tokens() > 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue