ResponseParser mutates AST directly, returns PendingToolCalls

The parser takes &mut ContextState on feed()/finish() and pushes
completed children (content, thinking, tool calls) directly into
the assistant branch. Only PendingToolCall handles are returned
to the caller for dispatch — the caller no longer manages AST
mutation.

Tests now verify the result by reading the pushed children back from
ContextState after parsing.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-08 14:33:57 -04:00
parent 6139d43942
commit 648356ae40

View file

@ -111,9 +111,10 @@ pub trait Ast {
} }
pub struct ResponseParser { pub struct ResponseParser {
branch_idx: usize,
call_counter: u32,
buf: String, buf: String,
content_parts: Vec<String>, content_parts: Vec<String>,
children: Vec<AstNode>,
in_think: bool, in_think: bool,
think_buf: String, think_buf: String,
in_tool_call: bool, in_tool_call: bool,
@ -460,11 +461,14 @@ fn parse_json_tool_call(body: &str) -> Option<(String, String)> {
} }
impl ResponseParser { impl ResponseParser {
pub fn new() -> Self { /// Create a parser that pushes children into the assistant branch
/// at `branch_idx` in the conversation section.
pub fn new(branch_idx: usize) -> Self {
Self { Self {
branch_idx,
call_counter: 0,
buf: String::new(), buf: String::new(),
content_parts: Vec::new(), content_parts: Vec::new(),
children: Vec::new(),
in_think: false, in_think: false,
think_buf: String::new(), think_buf: String::new(),
in_tool_call: false, in_tool_call: false,
@ -472,9 +476,10 @@ impl ResponseParser {
} }
} }
/// Feed a text chunk. Returns completed child nodes — the caller /// Feed a text chunk. Completed children are pushed directly into
/// pushes them into the assistant branch and dispatches any tool calls. /// the AST. Returns any tool calls that need dispatching.
pub fn feed(&mut self, text: &str) -> Vec<AstNode> { pub fn feed(&mut self, text: &str, ctx: &mut ContextState) -> Vec<PendingToolCall> {
let mut pending = Vec::new();
self.buf.push_str(text); self.buf.push_str(text);
loop { loop {
@ -484,7 +489,7 @@ impl ResponseParser {
self.think_buf.push_str(&self.buf[..end]); self.think_buf.push_str(&self.buf[..end]);
self.buf = self.buf[end + 8..].to_string(); self.buf = self.buf[end + 8..].to_string();
self.in_think = false; self.in_think = false;
self.children.push(AstNode::thinking(&self.think_buf)); self.push_child(ctx, AstNode::thinking(&self.think_buf));
self.think_buf.clear(); self.think_buf.clear();
continue; continue;
} }
@ -507,8 +512,14 @@ impl ResponseParser {
self.buf = self.buf[end + 12..].to_string(); self.buf = self.buf[end + 12..].to_string();
self.in_tool_call = false; self.in_tool_call = false;
if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) { if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) {
self.flush_content(); self.flush_content(ctx);
self.children.push(AstNode::tool_call(name, args)); self.push_child(ctx, AstNode::tool_call(&name, &args));
self.call_counter += 1;
pending.push(PendingToolCall {
name,
arguments: args,
id: format!("call_{}", self.call_counter),
});
} }
self.tool_call_buf.clear(); self.tool_call_buf.clear();
continue; continue;
@ -541,11 +552,11 @@ impl ResponseParser {
} }
if self.buf[pos..].starts_with("<think>") { if self.buf[pos..].starts_with("<think>") {
self.buf = self.buf[pos + 7..].to_string(); self.buf = self.buf[pos + 7..].to_string();
self.flush_content(); self.flush_content(ctx);
self.in_think = true; self.in_think = true;
} else { } else {
self.buf = self.buf[pos + 11..].to_string(); self.buf = self.buf[pos + 11..].to_string();
self.flush_content(); self.flush_content(ctx);
self.in_tool_call = true; self.in_tool_call = true;
} }
continue; continue;
@ -562,25 +573,28 @@ impl ResponseParser {
} }
} }
self.children.drain(..).collect() pending
} }
fn flush_content(&mut self) { fn push_child(&self, ctx: &mut ContextState, child: AstNode) {
ctx.push_child(Section::Conversation, self.branch_idx, child);
}
fn flush_content(&mut self, ctx: &mut ContextState) {
if !self.content_parts.is_empty() { if !self.content_parts.is_empty() {
let text: String = self.content_parts.drain(..).collect(); let text: String = self.content_parts.drain(..).collect();
if !text.is_empty() { if !text.is_empty() {
self.children.push(AstNode::content(text)); self.push_child(ctx, AstNode::content(text));
} }
} }
} }
/// Flush remaining buffer and return any final children. /// Flush remaining buffer into the AST.
pub fn finish(mut self) -> Vec<AstNode> { pub fn finish(mut self, ctx: &mut ContextState) {
if !self.buf.is_empty() { if !self.buf.is_empty() {
self.content_parts.push(std::mem::take(&mut self.buf)); self.content_parts.push(std::mem::take(&mut self.buf));
} }
self.flush_content(); self.flush_content(ctx);
self.children
} }
/// Current display text (content accumulated since last drain). /// Current display text (content accumulated since last drain).
@ -812,26 +826,36 @@ mod tests {
// -- ResponseParser tests ------------------------------------------------- // -- ResponseParser tests -------------------------------------------------
/// Collect all children from feed + finish. /// Set up a ContextState with an assistant branch, run the parser,
fn parse_all(text: &str) -> Vec<AstNode> { /// and return the resulting context plus any pending tool calls.
let mut p = ResponseParser::new(); fn parse_into_ctx(chunks: &[&str]) -> (ContextState, Vec<PendingToolCall>) {
let mut all = p.feed(text); let mut ctx = ContextState::new();
all.extend(p.finish()); ctx.push(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
all let mut p = ResponseParser::new(0);
let mut calls = Vec::new();
for chunk in chunks {
calls.extend(p.feed(chunk, &mut ctx));
}
p.finish(&mut ctx);
(ctx, calls)
}
fn assistant_children(ctx: &ContextState) -> &[AstNode] {
ctx.conversation()[0].children()
} }
#[test] #[test]
fn test_parser_plain_text() { fn test_parser_plain_text() {
let nodes = parse_all("hello world"); let (ctx, _) = parse_into_ctx(&["hello world"]);
let b = bodies(&nodes); let b = bodies(assistant_children(&ctx));
assert_eq!(b.len(), 1); assert_eq!(b.len(), 1);
assert_content(b[0], "hello world"); assert_content(b[0], "hello world");
} }
#[test] #[test]
fn test_parser_thinking_then_content() { fn test_parser_thinking_then_content() {
let nodes = parse_all("<think>reasoning</think>answer"); let (ctx, _) = parse_into_ctx(&["<think>reasoning</think>answer"]);
let b = bodies(&nodes); let b = bodies(assistant_children(&ctx));
assert_eq!(b.len(), 2); assert_eq!(b.len(), 2);
assert_thinking(b[0], "reasoning"); assert_thinking(b[0], "reasoning");
assert_content(b[1], "answer"); assert_content(b[1], "answer");
@ -839,11 +863,13 @@ mod tests {
#[test] #[test]
fn test_parser_tool_call() { fn test_parser_tool_call() {
let mut p = ResponseParser::new(); let (ctx, calls) = parse_into_ctx(&[
let children = p.feed("<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>"); "<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>"
// Tool call returned immediately from feed ]);
assert_eq!(children.len(), 1); assert_eq!(calls.len(), 1);
let b = bodies(&children); assert_eq!(calls[0].name, "bash");
let b = bodies(assistant_children(&ctx));
assert_eq!(b.len(), 1);
let args = assert_tool_call(b[0], "bash"); let args = assert_tool_call(b[0], "bash");
let args: serde_json::Value = serde_json::from_str(args).unwrap(); let args: serde_json::Value = serde_json::from_str(args).unwrap();
assert_eq!(args["command"], "ls"); assert_eq!(args["command"], "ls");
@ -851,12 +877,12 @@ mod tests {
#[test] #[test]
fn test_parser_content_then_tool_call_then_content() { fn test_parser_content_then_tool_call_then_content() {
let mut p = ResponseParser::new(); let (ctx, _) = parse_into_ctx(&[
let mut all = p.feed("before"); "before",
all.extend(p.feed("<tool_call>\n<function=bash>\n<parameter=command>pwd</parameter>\n</function>\n</tool_call>")); "<tool_call>\n<function=bash>\n<parameter=command>pwd</parameter>\n</function>\n</tool_call>",
all.extend(p.feed("after")); "after",
all.extend(p.finish()); ]);
let b = bodies(&all); let b = bodies(assistant_children(&ctx));
assert_eq!(b.len(), 3); assert_eq!(b.len(), 3);
assert_content(b[0], "before"); assert_content(b[0], "before");
assert_tool_call(b[1], "bash"); assert_tool_call(b[1], "bash");
@ -866,13 +892,14 @@ mod tests {
#[test] #[test]
fn test_parser_incremental_feed() { fn test_parser_incremental_feed() {
let text = "<think>thought</think>response"; let text = "<think>thought</think>response";
let mut p = ResponseParser::new(); let mut ctx = ContextState::new();
let mut all = Vec::new(); ctx.push(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
let mut p = ResponseParser::new(0);
for ch in text.chars() { for ch in text.chars() {
all.extend(p.feed(&ch.to_string())); p.feed(&ch.to_string(), &mut ctx);
} }
all.extend(p.finish()); p.finish(&mut ctx);
let b = bodies(&all); let b = bodies(assistant_children(&ctx));
assert_eq!(b.len(), 2); assert_eq!(b.len(), 2);
assert_thinking(b[0], "thought"); assert_thinking(b[0], "thought");
assert_content(b[1], "response"); assert_content(b[1], "response");
@ -881,23 +908,16 @@ mod tests {
#[test] #[test]
fn test_parser_incremental_tool_call() { fn test_parser_incremental_tool_call() {
let text = "text<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>more"; let text = "text<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>more";
let mut p = ResponseParser::new(); let mut ctx = ContextState::new();
let mut all = Vec::new(); ctx.push(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
let mut p = ResponseParser::new(0);
let mut tool_calls = 0; let mut tool_calls = 0;
for ch in text.chars() { for ch in text.chars() {
let children = p.feed(&ch.to_string()); tool_calls += p.feed(&ch.to_string(), &mut ctx).len();
for c in &children {
if let AstNode::Leaf(l) = c {
if matches!(l.body(), NodeBody::ToolCall { .. }) {
tool_calls += 1;
} }
} p.finish(&mut ctx);
}
all.extend(children);
}
all.extend(p.finish());
assert_eq!(tool_calls, 1); assert_eq!(tool_calls, 1);
let b = bodies(&all); let b = bodies(assistant_children(&ctx));
assert_eq!(b.len(), 3); assert_eq!(b.len(), 3);
assert_content(b[0], "text"); assert_content(b[0], "text");
assert_tool_call(b[1], "bash"); assert_tool_call(b[1], "bash");
@ -906,12 +926,12 @@ mod tests {
#[test] #[test]
fn test_parser_thinking_tool_call_content() { fn test_parser_thinking_tool_call_content() {
let mut p = ResponseParser::new(); let (ctx, _) = parse_into_ctx(&[
let mut all = p.feed("<think>let me think</think>"); "<think>let me think</think>",
all.extend(p.feed("<tool_call>\n<function=read>\n<parameter=path>/etc/hosts</parameter>\n</function>\n</tool_call>")); "<tool_call>\n<function=read>\n<parameter=path>/etc/hosts</parameter>\n</function>\n</tool_call>",
all.extend(p.feed("here's what I found")); "here's what I found",
all.extend(p.finish()); ]);
let b = bodies(&all); let b = bodies(assistant_children(&ctx));
assert_eq!(b.len(), 3); assert_eq!(b.len(), 3);
assert_thinking(b[0], "let me think"); assert_thinking(b[0], "let me think");
assert_tool_call(b[1], "read"); assert_tool_call(b[1], "read");
@ -1049,14 +1069,12 @@ mod tests {
fn test_parser_roundtrip_through_tokenizer() { fn test_parser_roundtrip_through_tokenizer() {
if !init_tokenizer() { return; } if !init_tokenizer() { return; }
let mut p = ResponseParser::new(); let (ctx, _) = parse_into_ctx(&[
let mut all = p.feed("I'll check that for you"); "I'll check that for you",
all.extend(p.feed("<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>")); "<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>",
all.extend(p.finish()); ]);
let node = &ctx.conversation()[0];
// Wrap in assistant branch to test full tokenization assert_token_invariants(node);
let node = AstNode::branch(Role::Assistant, all);
assert_token_invariants(&node);
assert!(node.tokens() > 0); assert!(node.tokens() > 0);
} }
} }