ResponseParser mutates AST directly, returns PendingToolCalls
The parser takes &mut ContextState on feed()/finish() and pushes completed children (content, thinking, tool calls) directly into the assistant branch. Only PendingToolCall handles are returned to the caller for dispatch — the caller no longer manages AST mutation. Tests verify by reading back from ContextState after parsing. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
6139d43942
commit
648356ae40
1 changed files with 89 additions and 71 deletions
|
|
@ -111,9 +111,10 @@ pub trait Ast {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ResponseParser {
|
pub struct ResponseParser {
|
||||||
|
branch_idx: usize,
|
||||||
|
call_counter: u32,
|
||||||
buf: String,
|
buf: String,
|
||||||
content_parts: Vec<String>,
|
content_parts: Vec<String>,
|
||||||
children: Vec<AstNode>,
|
|
||||||
in_think: bool,
|
in_think: bool,
|
||||||
think_buf: String,
|
think_buf: String,
|
||||||
in_tool_call: bool,
|
in_tool_call: bool,
|
||||||
|
|
@ -460,11 +461,14 @@ fn parse_json_tool_call(body: &str) -> Option<(String, String)> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResponseParser {
|
impl ResponseParser {
|
||||||
pub fn new() -> Self {
|
/// Create a parser that pushes children into the assistant branch
|
||||||
|
/// at `branch_idx` in the conversation section.
|
||||||
|
pub fn new(branch_idx: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
branch_idx,
|
||||||
|
call_counter: 0,
|
||||||
buf: String::new(),
|
buf: String::new(),
|
||||||
content_parts: Vec::new(),
|
content_parts: Vec::new(),
|
||||||
children: Vec::new(),
|
|
||||||
in_think: false,
|
in_think: false,
|
||||||
think_buf: String::new(),
|
think_buf: String::new(),
|
||||||
in_tool_call: false,
|
in_tool_call: false,
|
||||||
|
|
@ -472,9 +476,10 @@ impl ResponseParser {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Feed a text chunk. Returns completed child nodes — the caller
|
/// Feed a text chunk. Completed children are pushed directly into
|
||||||
/// pushes them into the assistant branch and dispatches any tool calls.
|
/// the AST. Returns any tool calls that need dispatching.
|
||||||
pub fn feed(&mut self, text: &str) -> Vec<AstNode> {
|
pub fn feed(&mut self, text: &str, ctx: &mut ContextState) -> Vec<PendingToolCall> {
|
||||||
|
let mut pending = Vec::new();
|
||||||
self.buf.push_str(text);
|
self.buf.push_str(text);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
|
@ -484,7 +489,7 @@ impl ResponseParser {
|
||||||
self.think_buf.push_str(&self.buf[..end]);
|
self.think_buf.push_str(&self.buf[..end]);
|
||||||
self.buf = self.buf[end + 8..].to_string();
|
self.buf = self.buf[end + 8..].to_string();
|
||||||
self.in_think = false;
|
self.in_think = false;
|
||||||
self.children.push(AstNode::thinking(&self.think_buf));
|
self.push_child(ctx, AstNode::thinking(&self.think_buf));
|
||||||
self.think_buf.clear();
|
self.think_buf.clear();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
@ -507,8 +512,14 @@ impl ResponseParser {
|
||||||
self.buf = self.buf[end + 12..].to_string();
|
self.buf = self.buf[end + 12..].to_string();
|
||||||
self.in_tool_call = false;
|
self.in_tool_call = false;
|
||||||
if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) {
|
if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) {
|
||||||
self.flush_content();
|
self.flush_content(ctx);
|
||||||
self.children.push(AstNode::tool_call(name, args));
|
self.push_child(ctx, AstNode::tool_call(&name, &args));
|
||||||
|
self.call_counter += 1;
|
||||||
|
pending.push(PendingToolCall {
|
||||||
|
name,
|
||||||
|
arguments: args,
|
||||||
|
id: format!("call_{}", self.call_counter),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
self.tool_call_buf.clear();
|
self.tool_call_buf.clear();
|
||||||
continue;
|
continue;
|
||||||
|
|
@ -541,11 +552,11 @@ impl ResponseParser {
|
||||||
}
|
}
|
||||||
if self.buf[pos..].starts_with("<think>") {
|
if self.buf[pos..].starts_with("<think>") {
|
||||||
self.buf = self.buf[pos + 7..].to_string();
|
self.buf = self.buf[pos + 7..].to_string();
|
||||||
self.flush_content();
|
self.flush_content(ctx);
|
||||||
self.in_think = true;
|
self.in_think = true;
|
||||||
} else {
|
} else {
|
||||||
self.buf = self.buf[pos + 11..].to_string();
|
self.buf = self.buf[pos + 11..].to_string();
|
||||||
self.flush_content();
|
self.flush_content(ctx);
|
||||||
self.in_tool_call = true;
|
self.in_tool_call = true;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
|
|
@ -562,25 +573,28 @@ impl ResponseParser {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.children.drain(..).collect()
|
pending
|
||||||
}
|
}
|
||||||
|
|
||||||
fn flush_content(&mut self) {
|
fn push_child(&self, ctx: &mut ContextState, child: AstNode) {
|
||||||
|
ctx.push_child(Section::Conversation, self.branch_idx, child);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush_content(&mut self, ctx: &mut ContextState) {
|
||||||
if !self.content_parts.is_empty() {
|
if !self.content_parts.is_empty() {
|
||||||
let text: String = self.content_parts.drain(..).collect();
|
let text: String = self.content_parts.drain(..).collect();
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
self.children.push(AstNode::content(text));
|
self.push_child(ctx, AstNode::content(text));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Flush remaining buffer and return any final children.
|
/// Flush remaining buffer into the AST.
|
||||||
pub fn finish(mut self) -> Vec<AstNode> {
|
pub fn finish(mut self, ctx: &mut ContextState) {
|
||||||
if !self.buf.is_empty() {
|
if !self.buf.is_empty() {
|
||||||
self.content_parts.push(std::mem::take(&mut self.buf));
|
self.content_parts.push(std::mem::take(&mut self.buf));
|
||||||
}
|
}
|
||||||
self.flush_content();
|
self.flush_content(ctx);
|
||||||
self.children
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Current display text (content accumulated since last drain).
|
/// Current display text (content accumulated since last drain).
|
||||||
|
|
@ -812,26 +826,36 @@ mod tests {
|
||||||
|
|
||||||
// -- ResponseParser tests -------------------------------------------------
|
// -- ResponseParser tests -------------------------------------------------
|
||||||
|
|
||||||
/// Collect all children from feed + finish.
|
/// Set up a ContextState with an assistant branch, run the parser,
|
||||||
fn parse_all(text: &str) -> Vec<AstNode> {
|
/// return the children that were pushed into the branch.
|
||||||
let mut p = ResponseParser::new();
|
fn parse_into_ctx(chunks: &[&str]) -> (ContextState, Vec<PendingToolCall>) {
|
||||||
let mut all = p.feed(text);
|
let mut ctx = ContextState::new();
|
||||||
all.extend(p.finish());
|
ctx.push(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
|
||||||
all
|
let mut p = ResponseParser::new(0);
|
||||||
|
let mut calls = Vec::new();
|
||||||
|
for chunk in chunks {
|
||||||
|
calls.extend(p.feed(chunk, &mut ctx));
|
||||||
|
}
|
||||||
|
p.finish(&mut ctx);
|
||||||
|
(ctx, calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assistant_children(ctx: &ContextState) -> &[AstNode] {
|
||||||
|
ctx.conversation()[0].children()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parser_plain_text() {
|
fn test_parser_plain_text() {
|
||||||
let nodes = parse_all("hello world");
|
let (ctx, _) = parse_into_ctx(&["hello world"]);
|
||||||
let b = bodies(&nodes);
|
let b = bodies(assistant_children(&ctx));
|
||||||
assert_eq!(b.len(), 1);
|
assert_eq!(b.len(), 1);
|
||||||
assert_content(b[0], "hello world");
|
assert_content(b[0], "hello world");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parser_thinking_then_content() {
|
fn test_parser_thinking_then_content() {
|
||||||
let nodes = parse_all("<think>reasoning</think>answer");
|
let (ctx, _) = parse_into_ctx(&["<think>reasoning</think>answer"]);
|
||||||
let b = bodies(&nodes);
|
let b = bodies(assistant_children(&ctx));
|
||||||
assert_eq!(b.len(), 2);
|
assert_eq!(b.len(), 2);
|
||||||
assert_thinking(b[0], "reasoning");
|
assert_thinking(b[0], "reasoning");
|
||||||
assert_content(b[1], "answer");
|
assert_content(b[1], "answer");
|
||||||
|
|
@ -839,11 +863,13 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parser_tool_call() {
|
fn test_parser_tool_call() {
|
||||||
let mut p = ResponseParser::new();
|
let (ctx, calls) = parse_into_ctx(&[
|
||||||
let children = p.feed("<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>");
|
"<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>"
|
||||||
// Tool call returned immediately from feed
|
]);
|
||||||
assert_eq!(children.len(), 1);
|
assert_eq!(calls.len(), 1);
|
||||||
let b = bodies(&children);
|
assert_eq!(calls[0].name, "bash");
|
||||||
|
let b = bodies(assistant_children(&ctx));
|
||||||
|
assert_eq!(b.len(), 1);
|
||||||
let args = assert_tool_call(b[0], "bash");
|
let args = assert_tool_call(b[0], "bash");
|
||||||
let args: serde_json::Value = serde_json::from_str(args).unwrap();
|
let args: serde_json::Value = serde_json::from_str(args).unwrap();
|
||||||
assert_eq!(args["command"], "ls");
|
assert_eq!(args["command"], "ls");
|
||||||
|
|
@ -851,12 +877,12 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parser_content_then_tool_call_then_content() {
|
fn test_parser_content_then_tool_call_then_content() {
|
||||||
let mut p = ResponseParser::new();
|
let (ctx, _) = parse_into_ctx(&[
|
||||||
let mut all = p.feed("before");
|
"before",
|
||||||
all.extend(p.feed("<tool_call>\n<function=bash>\n<parameter=command>pwd</parameter>\n</function>\n</tool_call>"));
|
"<tool_call>\n<function=bash>\n<parameter=command>pwd</parameter>\n</function>\n</tool_call>",
|
||||||
all.extend(p.feed("after"));
|
"after",
|
||||||
all.extend(p.finish());
|
]);
|
||||||
let b = bodies(&all);
|
let b = bodies(assistant_children(&ctx));
|
||||||
assert_eq!(b.len(), 3);
|
assert_eq!(b.len(), 3);
|
||||||
assert_content(b[0], "before");
|
assert_content(b[0], "before");
|
||||||
assert_tool_call(b[1], "bash");
|
assert_tool_call(b[1], "bash");
|
||||||
|
|
@ -866,13 +892,14 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parser_incremental_feed() {
|
fn test_parser_incremental_feed() {
|
||||||
let text = "<think>thought</think>response";
|
let text = "<think>thought</think>response";
|
||||||
let mut p = ResponseParser::new();
|
let mut ctx = ContextState::new();
|
||||||
let mut all = Vec::new();
|
ctx.push(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
|
||||||
|
let mut p = ResponseParser::new(0);
|
||||||
for ch in text.chars() {
|
for ch in text.chars() {
|
||||||
all.extend(p.feed(&ch.to_string()));
|
p.feed(&ch.to_string(), &mut ctx);
|
||||||
}
|
}
|
||||||
all.extend(p.finish());
|
p.finish(&mut ctx);
|
||||||
let b = bodies(&all);
|
let b = bodies(assistant_children(&ctx));
|
||||||
assert_eq!(b.len(), 2);
|
assert_eq!(b.len(), 2);
|
||||||
assert_thinking(b[0], "thought");
|
assert_thinking(b[0], "thought");
|
||||||
assert_content(b[1], "response");
|
assert_content(b[1], "response");
|
||||||
|
|
@ -881,23 +908,16 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parser_incremental_tool_call() {
|
fn test_parser_incremental_tool_call() {
|
||||||
let text = "text<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>more";
|
let text = "text<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>more";
|
||||||
let mut p = ResponseParser::new();
|
let mut ctx = ContextState::new();
|
||||||
let mut all = Vec::new();
|
ctx.push(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
|
||||||
|
let mut p = ResponseParser::new(0);
|
||||||
let mut tool_calls = 0;
|
let mut tool_calls = 0;
|
||||||
for ch in text.chars() {
|
for ch in text.chars() {
|
||||||
let children = p.feed(&ch.to_string());
|
tool_calls += p.feed(&ch.to_string(), &mut ctx).len();
|
||||||
for c in &children {
|
|
||||||
if let AstNode::Leaf(l) = c {
|
|
||||||
if matches!(l.body(), NodeBody::ToolCall { .. }) {
|
|
||||||
tool_calls += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
all.extend(children);
|
|
||||||
}
|
}
|
||||||
all.extend(p.finish());
|
p.finish(&mut ctx);
|
||||||
assert_eq!(tool_calls, 1);
|
assert_eq!(tool_calls, 1);
|
||||||
let b = bodies(&all);
|
let b = bodies(assistant_children(&ctx));
|
||||||
assert_eq!(b.len(), 3);
|
assert_eq!(b.len(), 3);
|
||||||
assert_content(b[0], "text");
|
assert_content(b[0], "text");
|
||||||
assert_tool_call(b[1], "bash");
|
assert_tool_call(b[1], "bash");
|
||||||
|
|
@ -906,12 +926,12 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parser_thinking_tool_call_content() {
|
fn test_parser_thinking_tool_call_content() {
|
||||||
let mut p = ResponseParser::new();
|
let (ctx, _) = parse_into_ctx(&[
|
||||||
let mut all = p.feed("<think>let me think</think>");
|
"<think>let me think</think>",
|
||||||
all.extend(p.feed("<tool_call>\n<function=read>\n<parameter=path>/etc/hosts</parameter>\n</function>\n</tool_call>"));
|
"<tool_call>\n<function=read>\n<parameter=path>/etc/hosts</parameter>\n</function>\n</tool_call>",
|
||||||
all.extend(p.feed("here's what I found"));
|
"here's what I found",
|
||||||
all.extend(p.finish());
|
]);
|
||||||
let b = bodies(&all);
|
let b = bodies(assistant_children(&ctx));
|
||||||
assert_eq!(b.len(), 3);
|
assert_eq!(b.len(), 3);
|
||||||
assert_thinking(b[0], "let me think");
|
assert_thinking(b[0], "let me think");
|
||||||
assert_tool_call(b[1], "read");
|
assert_tool_call(b[1], "read");
|
||||||
|
|
@ -1049,14 +1069,12 @@ mod tests {
|
||||||
fn test_parser_roundtrip_through_tokenizer() {
|
fn test_parser_roundtrip_through_tokenizer() {
|
||||||
if !init_tokenizer() { return; }
|
if !init_tokenizer() { return; }
|
||||||
|
|
||||||
let mut p = ResponseParser::new();
|
let (ctx, _) = parse_into_ctx(&[
|
||||||
let mut all = p.feed("I'll check that for you");
|
"I'll check that for you",
|
||||||
all.extend(p.feed("<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>"));
|
"<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>",
|
||||||
all.extend(p.finish());
|
]);
|
||||||
|
let node = &ctx.conversation()[0];
|
||||||
// Wrap in assistant branch to test full tokenization
|
assert_token_invariants(node);
|
||||||
let node = AstNode::branch(Role::Assistant, all);
|
|
||||||
assert_token_invariants(&node);
|
|
||||||
assert!(node.tokens() > 0);
|
assert!(node.tokens() > 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue