Parser consumes stream directly, yields tool calls via channel

ResponseParser::run() spawns a task that reads StreamTokens, parses
into the AST (locking context per token), and sends PendingToolCalls
through a channel. Returns (tool_rx, JoinHandle<Result>) — the turn
loop dispatches tool calls and awaits the handle for error checking.

Token IDs from vLLM are accumulated alongside text and stored directly
on AST leaves — no local re-encoding on the response path.

The turn loop no longer matches on individual stream events. It just
reads tool calls and dispatches them.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-08 16:32:00 -04:00
parent 0b9813431a
commit 2c401e24d6
3 changed files with 119 additions and 85 deletions

View file

@ -8,14 +8,13 @@
pub mod http; pub mod http;
use anyhow::Result;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use anyhow::Result;
use self::http::{HttpClient, HttpResponse};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use serde::Deserialize; use serde::Deserialize;
use http::{HttpClient, HttpResponse};
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub struct Usage { pub struct Usage {
pub prompt_tokens: u32, pub prompt_tokens: u32,

View file

@ -115,11 +115,15 @@ pub struct ResponseParser {
branch_idx: usize, branch_idx: usize,
call_counter: u32, call_counter: u32,
buf: String, buf: String,
buf_token_ids: Vec<u32>,
content_parts: Vec<String>, content_parts: Vec<String>,
content_token_ids: Vec<u32>,
in_think: bool, in_think: bool,
think_buf: String, think_buf: String,
think_token_ids: Vec<u32>,
in_tool_call: bool, in_tool_call: bool,
tool_call_buf: String, tool_call_buf: String,
tool_call_token_ids: Vec<u32>,
} }
impl Role { impl Role {
@ -462,36 +466,80 @@ fn parse_json_tool_call(body: &str) -> Option<(String, String)> {
} }
impl ResponseParser { impl ResponseParser {
/// Create a parser that pushes children into the assistant branch
/// at `branch_idx` in the conversation section.
pub fn new(branch_idx: usize) -> Self { pub fn new(branch_idx: usize) -> Self {
Self { Self {
branch_idx, branch_idx,
call_counter: 0, call_counter: 0,
buf: String::new(), buf: String::new(),
buf_token_ids: Vec::new(),
content_parts: Vec::new(), content_parts: Vec::new(),
content_token_ids: Vec::new(),
in_think: false, in_think: false,
think_buf: String::new(), think_buf: String::new(),
think_token_ids: Vec::new(),
in_tool_call: false, in_tool_call: false,
tool_call_buf: String::new(), tool_call_buf: String::new(),
tool_call_token_ids: Vec::new(),
} }
} }
/// Feed a text chunk. Completed children are pushed directly into /// Consume a token stream, parse into the AST, yield tool calls.
/// the AST. Returns any tool calls that need dispatching. /// Spawns a background task. Returns a tool call receiver and a
pub fn feed(&mut self, text: &str, ctx: &mut ContextState) -> Vec<PendingToolCall> { /// join handle that resolves to Ok(()) or the stream error.
pub fn run(
self,
mut stream: tokio::sync::mpsc::UnboundedReceiver<super::api::StreamToken>,
agent: std::sync::Arc<super::Agent>,
) -> (
tokio::sync::mpsc::UnboundedReceiver<PendingToolCall>,
tokio::task::JoinHandle<anyhow::Result<()>>,
) {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let handle = tokio::spawn(async move {
let mut parser = self;
while let Some(event) = stream.recv().await {
match event {
super::api::StreamToken::Token { text, id } => {
let mut ctx = agent.context.lock().await;
for call in parser.feed_token(&text, id, &mut ctx) {
let _ = tx.send(call);
}
}
super::api::StreamToken::Done { usage } => {
if let Some(u) = usage {
agent.state.lock().await.last_prompt_tokens = u.prompt_tokens;
}
let mut ctx = agent.context.lock().await;
parser.finish(&mut ctx);
return Ok(());
}
super::api::StreamToken::Error(e) => {
return Err(anyhow::anyhow!("{}", e));
}
}
}
Ok(())
});
(rx, handle)
}
pub fn feed_token(&mut self, text: &str, token_id: u32, ctx: &mut ContextState) -> Vec<PendingToolCall> {
let mut pending = Vec::new(); let mut pending = Vec::new();
self.buf.push_str(text); self.buf.push_str(text);
self.buf_token_ids.push(token_id);
loop { loop {
if self.in_think { if self.in_think {
match self.buf.find("</think>") { match self.buf.find("</think>") {
Some(end) => { Some(end) => {
self.think_buf.push_str(&self.buf[..end]); self.think_buf.push_str(&self.buf[..end]);
// Token IDs: move all buffered IDs to think (approximate split)
self.think_token_ids.extend(self.buf_token_ids.drain(..));
self.buf = self.buf[end + 8..].to_string(); self.buf = self.buf[end + 8..].to_string();
self.in_think = false; self.in_think = false;
self.push_child(ctx, AstNode::thinking(&self.think_buf)); let text = std::mem::take(&mut self.think_buf);
self.think_buf.clear(); let ids = std::mem::take(&mut self.think_token_ids);
self.push_child_with_tokens(ctx, NodeBody::Thinking(text), ids);
continue; continue;
} }
None => { None => {
@ -500,6 +548,7 @@ impl ResponseParser {
let safe = self.buf.floor_char_boundary(safe); let safe = self.buf.floor_char_boundary(safe);
self.think_buf.push_str(&self.buf[..safe]); self.think_buf.push_str(&self.buf[..safe]);
self.buf = self.buf[safe..].to_string(); self.buf = self.buf[safe..].to_string();
// Keep token IDs in buf (lookahead)
} }
break; break;
} }
@ -510,10 +559,12 @@ impl ResponseParser {
match self.buf.find("</tool_call>") { match self.buf.find("</tool_call>") {
Some(end) => { Some(end) => {
self.tool_call_buf.push_str(&self.buf[..end]); self.tool_call_buf.push_str(&self.buf[..end]);
self.tool_call_token_ids.extend(self.buf_token_ids.drain(..));
self.buf = self.buf[end + 12..].to_string(); self.buf = self.buf[end + 12..].to_string();
self.in_tool_call = false; self.in_tool_call = false;
if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) { if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) {
self.flush_content(ctx); self.flush_content(ctx);
// Tool calls get re-tokenized from structured data
self.push_child(ctx, AstNode::tool_call(&name, &args)); self.push_child(ctx, AstNode::tool_call(&name, &args));
self.call_counter += 1; self.call_counter += 1;
pending.push(PendingToolCall { pending.push(PendingToolCall {
@ -523,6 +574,7 @@ impl ResponseParser {
}); });
} }
self.tool_call_buf.clear(); self.tool_call_buf.clear();
self.tool_call_token_ids.clear();
continue; continue;
} }
None => { None => {
@ -551,6 +603,8 @@ impl ResponseParser {
if pos > 0 { if pos > 0 {
self.content_parts.push(self.buf[..pos].to_string()); self.content_parts.push(self.buf[..pos].to_string());
} }
// Move token IDs to content accumulator
self.content_token_ids.extend(self.buf_token_ids.drain(..));
if self.buf[pos..].starts_with("<think>") { if self.buf[pos..].starts_with("<think>") {
self.buf = self.buf[pos + 7..].to_string(); self.buf = self.buf[pos + 7..].to_string();
self.flush_content(ctx); self.flush_content(ctx);
@ -568,6 +622,7 @@ impl ResponseParser {
let safe = self.buf.floor_char_boundary(safe); let safe = self.buf.floor_char_boundary(safe);
self.content_parts.push(self.buf[..safe].to_string()); self.content_parts.push(self.buf[..safe].to_string());
self.buf = self.buf[safe..].to_string(); self.buf = self.buf[safe..].to_string();
// Keep token IDs in buf (lookahead)
} }
break; break;
} }
@ -581,27 +636,28 @@ impl ResponseParser {
ctx.push_child(Section::Conversation, self.branch_idx, child); ctx.push_child(Section::Conversation, self.branch_idx, child);
} }
/// Append a leaf node to the assistant branch, carrying the exact token
/// IDs accumulated from the stream (avoids re-encoding the text locally
/// on the response path).
fn push_child_with_tokens(&self, ctx: &mut ContextState, body: NodeBody, token_ids: Vec<u32>) {
// timestamp: None — streamed children get no timestamp here; presumably
// assigned elsewhere if needed (TODO confirm against NodeLeaf usage).
let leaf = NodeLeaf { body, token_ids, timestamp: None };
ctx.push_child(Section::Conversation, self.branch_idx, AstNode::Leaf(leaf));
}
fn flush_content(&mut self, ctx: &mut ContextState) { fn flush_content(&mut self, ctx: &mut ContextState) {
if !self.content_parts.is_empty() { if !self.content_parts.is_empty() {
let text: String = self.content_parts.drain(..).collect(); let text: String = self.content_parts.drain(..).collect();
if !text.is_empty() { if !text.is_empty() {
self.push_child(ctx, AstNode::content(text)); let token_ids = std::mem::take(&mut self.content_token_ids);
self.push_child_with_tokens(ctx, NodeBody::Content(text), token_ids);
} }
} }
} }
/// Flush remaining buffer into the AST.
pub fn finish(mut self, ctx: &mut ContextState) { pub fn finish(mut self, ctx: &mut ContextState) {
if !self.buf.is_empty() { if !self.buf.is_empty() {
self.content_parts.push(std::mem::take(&mut self.buf)); self.content_parts.push(std::mem::take(&mut self.buf));
self.content_token_ids.extend(self.buf_token_ids.drain(..));
} }
self.flush_content(ctx); self.flush_content(ctx);
} }
/// Text of the content parts buffered since the last flush, joined in
/// order — the live display string for the in-progress response.
pub fn display_content(&self) -> String {
    self.content_parts.concat()
}
} }
impl ContextState { impl ContextState {
@ -838,7 +894,8 @@ mod tests {
let mut p = ResponseParser::new(0); let mut p = ResponseParser::new(0);
let mut calls = Vec::new(); let mut calls = Vec::new();
for chunk in chunks { for chunk in chunks {
calls.extend(p.feed(chunk, &mut ctx)); // Feed each chunk as a single token (id=0 for tests)
calls.extend(p.feed_token(chunk, 0, &mut ctx));
} }
p.finish(&mut ctx); p.finish(&mut ctx);
(ctx, calls) (ctx, calls)
@ -900,7 +957,7 @@ mod tests {
ctx.push(Section::Conversation, AstNode::branch(Role::Assistant, vec![])); ctx.push(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
let mut p = ResponseParser::new(0); let mut p = ResponseParser::new(0);
for ch in text.chars() { for ch in text.chars() {
p.feed(&ch.to_string(), &mut ctx); p.feed_token(&ch.to_string(), 0, &mut ctx);
} }
p.finish(&mut ctx); p.finish(&mut ctx);
let b = bodies(assistant_children(&ctx)); let b = bodies(assistant_children(&ctx));
@ -917,7 +974,7 @@ mod tests {
let mut p = ResponseParser::new(0); let mut p = ResponseParser::new(0);
let mut tool_calls = 0; let mut tool_calls = 0;
for ch in text.chars() { for ch in text.chars() {
tool_calls += p.feed(&ch.to_string(), &mut ctx).len(); tool_calls += p.feed_token(&ch.to_string(), 0, &mut ctx).len();
} }
p.finish(&mut ctx); p.finish(&mut ctx);
assert_eq!(tool_calls, 1); assert_eq!(tool_calls, 1);

View file

@ -339,77 +339,55 @@ impl Agent {
AstNode::branch(Role::Assistant, vec![])); AstNode::branch(Role::Assistant, vec![]));
idx idx
}; };
let mut parser = ResponseParser::new(branch_idx);
let mut pending_calls: Vec<PendingToolCall> = Vec::new();
let mut had_content = false;
let mut stream_error: Option<String> = None;
// Stream loop — no lock held across I/O let parser = ResponseParser::new(branch_idx);
while let Some(event) = rx.recv().await { let (mut tool_rx, parser_handle) = parser.run(rx, agent.clone());
match event {
api::StreamToken::Token { text, id: _ } => { let mut pending_calls: Vec<PendingToolCall> = Vec::new();
had_content = true; while let Some(call) = tool_rx.recv().await {
let mut ctx = agent.context.lock().await; let call_clone = call.clone();
let calls = parser.feed(&text, &mut ctx); let agent_handle = agent.clone();
drop(ctx); let handle = tokio::spawn(async move {
for call in calls { let args: serde_json::Value =
let call_clone = call.clone(); serde_json::from_str(&call_clone.arguments).unwrap_or_default();
let agent_handle = agent.clone(); let output = tools::dispatch_with_agent(
let handle = tokio::spawn(async move { &call_clone.name, &args, Some(agent_handle),
let args: serde_json::Value = ).await;
serde_json::from_str(&call_clone.arguments).unwrap_or_default(); (call_clone, output)
let output = tools::dispatch_with_agent( });
&call_clone.name, &args, Some(agent_handle), active_tools.lock().unwrap().push(tools::ActiveToolCall {
).await; id: call.id.clone(),
(call_clone, output) name: call.name.clone(),
}); detail: call.arguments.clone(),
active_tools.lock().unwrap().push(tools::ActiveToolCall { started: std::time::Instant::now(),
id: call.id.clone(), background: false,
name: call.name.clone(), handle,
detail: call.arguments.clone(), });
started: std::time::Instant::now(), pending_calls.push(call);
background: false,
handle,
});
pending_calls.push(call);
}
}
api::StreamToken::Error(e) => {
stream_error = Some(e);
break;
}
api::StreamToken::Done { usage } => {
if let Some(u) = usage {
agent.state.lock().await.last_prompt_tokens = u.prompt_tokens;
}
break;
}
}
} }
// Flush parser remainder // Check for stream/parse errors
parser.finish(&mut *agent.context.lock().await); match parser_handle.await {
Ok(Err(e)) => {
// Handle errors if context::is_context_overflow(&e) && overflow_retries < 2 {
if let Some(e) = stream_error { overflow_retries += 1;
let err = anyhow::anyhow!("{}", e); agent.state.lock().await.notify(
if context::is_context_overflow(&err) && overflow_retries < 2 { format!("context overflow — retrying ({}/2)", overflow_retries));
overflow_retries += 1; agent.compact().await;
agent.state.lock().await.notify(format!("context overflow — retrying ({}/2)", overflow_retries)); continue;
agent.compact().await; }
continue; return Err(e);
} }
if context::is_stream_error(&err) && empty_retries < 2 { Err(e) => return Err(anyhow::anyhow!("parser task panicked: {}", e)),
empty_retries += 1; Ok(Ok(())) => {}
agent.state.lock().await.notify(format!("stream error — retrying ({}/2)", empty_retries));
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
continue;
}
return Err(err);
} }
// Empty response — nudge and retry // Empty response — nudge and retry
if !had_content && pending_calls.is_empty() { let has_content = {
let ctx = agent.context.lock().await;
!ctx.conversation()[branch_idx].children().is_empty()
};
if !has_content && pending_calls.is_empty() {
if empty_retries < 2 { if empty_retries < 2 {
empty_retries += 1; empty_retries += 1;
agent.push_node(AstNode::user_msg( agent.push_node(AstNode::user_msg(