Parser consumes stream directly, yields tool calls via channel
ResponseParser::run() spawns a task that reads StreamTokens, parses into the AST (locking context per token), and sends PendingToolCalls through a channel. Returns (tool_rx, JoinHandle<Result>) — the turn loop dispatches tool calls and awaits the handle for error checking. Token IDs from vLLM are accumulated alongside text and stored directly on AST leaves — no local re-encoding on the response path. The turn loop no longer matches on individual stream events. It just reads tool calls and dispatches them. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
0b9813431a
commit
2c401e24d6
3 changed files with 119 additions and 85 deletions
|
|
@ -8,14 +8,13 @@
|
||||||
|
|
||||||
pub mod http;
|
pub mod http;
|
||||||
|
|
||||||
use anyhow::Result;
|
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
use anyhow::Result;
|
||||||
use self::http::{HttpClient, HttpResponse};
|
|
||||||
|
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
use http::{HttpClient, HttpResponse};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
pub struct Usage {
|
pub struct Usage {
|
||||||
pub prompt_tokens: u32,
|
pub prompt_tokens: u32,
|
||||||
|
|
|
||||||
|
|
@ -115,11 +115,15 @@ pub struct ResponseParser {
|
||||||
branch_idx: usize,
|
branch_idx: usize,
|
||||||
call_counter: u32,
|
call_counter: u32,
|
||||||
buf: String,
|
buf: String,
|
||||||
|
buf_token_ids: Vec<u32>,
|
||||||
content_parts: Vec<String>,
|
content_parts: Vec<String>,
|
||||||
|
content_token_ids: Vec<u32>,
|
||||||
in_think: bool,
|
in_think: bool,
|
||||||
think_buf: String,
|
think_buf: String,
|
||||||
|
think_token_ids: Vec<u32>,
|
||||||
in_tool_call: bool,
|
in_tool_call: bool,
|
||||||
tool_call_buf: String,
|
tool_call_buf: String,
|
||||||
|
tool_call_token_ids: Vec<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Role {
|
impl Role {
|
||||||
|
|
@ -462,36 +466,80 @@ fn parse_json_tool_call(body: &str) -> Option<(String, String)> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResponseParser {
|
impl ResponseParser {
|
||||||
/// Create a parser that pushes children into the assistant branch
|
|
||||||
/// at `branch_idx` in the conversation section.
|
|
||||||
pub fn new(branch_idx: usize) -> Self {
|
pub fn new(branch_idx: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
branch_idx,
|
branch_idx,
|
||||||
call_counter: 0,
|
call_counter: 0,
|
||||||
buf: String::new(),
|
buf: String::new(),
|
||||||
|
buf_token_ids: Vec::new(),
|
||||||
content_parts: Vec::new(),
|
content_parts: Vec::new(),
|
||||||
|
content_token_ids: Vec::new(),
|
||||||
in_think: false,
|
in_think: false,
|
||||||
think_buf: String::new(),
|
think_buf: String::new(),
|
||||||
|
think_token_ids: Vec::new(),
|
||||||
in_tool_call: false,
|
in_tool_call: false,
|
||||||
tool_call_buf: String::new(),
|
tool_call_buf: String::new(),
|
||||||
|
tool_call_token_ids: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Feed a text chunk. Completed children are pushed directly into
|
/// Consume a token stream, parse into the AST, yield tool calls.
|
||||||
/// the AST. Returns any tool calls that need dispatching.
|
/// Spawns a background task. Returns a tool call receiver and a
|
||||||
pub fn feed(&mut self, text: &str, ctx: &mut ContextState) -> Vec<PendingToolCall> {
|
/// join handle that resolves to Ok(()) or the stream error.
|
||||||
|
pub fn run(
|
||||||
|
self,
|
||||||
|
mut stream: tokio::sync::mpsc::UnboundedReceiver<super::api::StreamToken>,
|
||||||
|
agent: std::sync::Arc<super::Agent>,
|
||||||
|
) -> (
|
||||||
|
tokio::sync::mpsc::UnboundedReceiver<PendingToolCall>,
|
||||||
|
tokio::task::JoinHandle<anyhow::Result<()>>,
|
||||||
|
) {
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
let mut parser = self;
|
||||||
|
while let Some(event) = stream.recv().await {
|
||||||
|
match event {
|
||||||
|
super::api::StreamToken::Token { text, id } => {
|
||||||
|
let mut ctx = agent.context.lock().await;
|
||||||
|
for call in parser.feed_token(&text, id, &mut ctx) {
|
||||||
|
let _ = tx.send(call);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
super::api::StreamToken::Done { usage } => {
|
||||||
|
if let Some(u) = usage {
|
||||||
|
agent.state.lock().await.last_prompt_tokens = u.prompt_tokens;
|
||||||
|
}
|
||||||
|
let mut ctx = agent.context.lock().await;
|
||||||
|
parser.finish(&mut ctx);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
super::api::StreamToken::Error(e) => {
|
||||||
|
return Err(anyhow::anyhow!("{}", e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
(rx, handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn feed_token(&mut self, text: &str, token_id: u32, ctx: &mut ContextState) -> Vec<PendingToolCall> {
|
||||||
let mut pending = Vec::new();
|
let mut pending = Vec::new();
|
||||||
self.buf.push_str(text);
|
self.buf.push_str(text);
|
||||||
|
self.buf_token_ids.push(token_id);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
if self.in_think {
|
if self.in_think {
|
||||||
match self.buf.find("</think>") {
|
match self.buf.find("</think>") {
|
||||||
Some(end) => {
|
Some(end) => {
|
||||||
self.think_buf.push_str(&self.buf[..end]);
|
self.think_buf.push_str(&self.buf[..end]);
|
||||||
|
// Token IDs: move all buffered IDs to think (approximate split)
|
||||||
|
self.think_token_ids.extend(self.buf_token_ids.drain(..));
|
||||||
self.buf = self.buf[end + 8..].to_string();
|
self.buf = self.buf[end + 8..].to_string();
|
||||||
self.in_think = false;
|
self.in_think = false;
|
||||||
self.push_child(ctx, AstNode::thinking(&self.think_buf));
|
let text = std::mem::take(&mut self.think_buf);
|
||||||
self.think_buf.clear();
|
let ids = std::mem::take(&mut self.think_token_ids);
|
||||||
|
self.push_child_with_tokens(ctx, NodeBody::Thinking(text), ids);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
|
|
@ -500,6 +548,7 @@ impl ResponseParser {
|
||||||
let safe = self.buf.floor_char_boundary(safe);
|
let safe = self.buf.floor_char_boundary(safe);
|
||||||
self.think_buf.push_str(&self.buf[..safe]);
|
self.think_buf.push_str(&self.buf[..safe]);
|
||||||
self.buf = self.buf[safe..].to_string();
|
self.buf = self.buf[safe..].to_string();
|
||||||
|
// Keep token IDs in buf (lookahead)
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -510,10 +559,12 @@ impl ResponseParser {
|
||||||
match self.buf.find("</tool_call>") {
|
match self.buf.find("</tool_call>") {
|
||||||
Some(end) => {
|
Some(end) => {
|
||||||
self.tool_call_buf.push_str(&self.buf[..end]);
|
self.tool_call_buf.push_str(&self.buf[..end]);
|
||||||
|
self.tool_call_token_ids.extend(self.buf_token_ids.drain(..));
|
||||||
self.buf = self.buf[end + 12..].to_string();
|
self.buf = self.buf[end + 12..].to_string();
|
||||||
self.in_tool_call = false;
|
self.in_tool_call = false;
|
||||||
if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) {
|
if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) {
|
||||||
self.flush_content(ctx);
|
self.flush_content(ctx);
|
||||||
|
// Tool calls get re-tokenized from structured data
|
||||||
self.push_child(ctx, AstNode::tool_call(&name, &args));
|
self.push_child(ctx, AstNode::tool_call(&name, &args));
|
||||||
self.call_counter += 1;
|
self.call_counter += 1;
|
||||||
pending.push(PendingToolCall {
|
pending.push(PendingToolCall {
|
||||||
|
|
@ -523,6 +574,7 @@ impl ResponseParser {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
self.tool_call_buf.clear();
|
self.tool_call_buf.clear();
|
||||||
|
self.tool_call_token_ids.clear();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
|
|
@ -551,6 +603,8 @@ impl ResponseParser {
|
||||||
if pos > 0 {
|
if pos > 0 {
|
||||||
self.content_parts.push(self.buf[..pos].to_string());
|
self.content_parts.push(self.buf[..pos].to_string());
|
||||||
}
|
}
|
||||||
|
// Move token IDs to content accumulator
|
||||||
|
self.content_token_ids.extend(self.buf_token_ids.drain(..));
|
||||||
if self.buf[pos..].starts_with("<think>") {
|
if self.buf[pos..].starts_with("<think>") {
|
||||||
self.buf = self.buf[pos + 7..].to_string();
|
self.buf = self.buf[pos + 7..].to_string();
|
||||||
self.flush_content(ctx);
|
self.flush_content(ctx);
|
||||||
|
|
@ -568,6 +622,7 @@ impl ResponseParser {
|
||||||
let safe = self.buf.floor_char_boundary(safe);
|
let safe = self.buf.floor_char_boundary(safe);
|
||||||
self.content_parts.push(self.buf[..safe].to_string());
|
self.content_parts.push(self.buf[..safe].to_string());
|
||||||
self.buf = self.buf[safe..].to_string();
|
self.buf = self.buf[safe..].to_string();
|
||||||
|
// Keep token IDs in buf (lookahead)
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -581,27 +636,28 @@ impl ResponseParser {
|
||||||
ctx.push_child(Section::Conversation, self.branch_idx, child);
|
ctx.push_child(Section::Conversation, self.branch_idx, child);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn push_child_with_tokens(&self, ctx: &mut ContextState, body: NodeBody, token_ids: Vec<u32>) {
|
||||||
|
let leaf = NodeLeaf { body, token_ids, timestamp: None };
|
||||||
|
ctx.push_child(Section::Conversation, self.branch_idx, AstNode::Leaf(leaf));
|
||||||
|
}
|
||||||
|
|
||||||
fn flush_content(&mut self, ctx: &mut ContextState) {
|
fn flush_content(&mut self, ctx: &mut ContextState) {
|
||||||
if !self.content_parts.is_empty() {
|
if !self.content_parts.is_empty() {
|
||||||
let text: String = self.content_parts.drain(..).collect();
|
let text: String = self.content_parts.drain(..).collect();
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
self.push_child(ctx, AstNode::content(text));
|
let token_ids = std::mem::take(&mut self.content_token_ids);
|
||||||
|
self.push_child_with_tokens(ctx, NodeBody::Content(text), token_ids);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Flush remaining buffer into the AST.
|
|
||||||
pub fn finish(mut self, ctx: &mut ContextState) {
|
pub fn finish(mut self, ctx: &mut ContextState) {
|
||||||
if !self.buf.is_empty() {
|
if !self.buf.is_empty() {
|
||||||
self.content_parts.push(std::mem::take(&mut self.buf));
|
self.content_parts.push(std::mem::take(&mut self.buf));
|
||||||
|
self.content_token_ids.extend(self.buf_token_ids.drain(..));
|
||||||
}
|
}
|
||||||
self.flush_content(ctx);
|
self.flush_content(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Current display text (content accumulated since last drain).
|
|
||||||
pub fn display_content(&self) -> String {
|
|
||||||
self.content_parts.join("")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ContextState {
|
impl ContextState {
|
||||||
|
|
@ -838,7 +894,8 @@ mod tests {
|
||||||
let mut p = ResponseParser::new(0);
|
let mut p = ResponseParser::new(0);
|
||||||
let mut calls = Vec::new();
|
let mut calls = Vec::new();
|
||||||
for chunk in chunks {
|
for chunk in chunks {
|
||||||
calls.extend(p.feed(chunk, &mut ctx));
|
// Feed each chunk as a single token (id=0 for tests)
|
||||||
|
calls.extend(p.feed_token(chunk, 0, &mut ctx));
|
||||||
}
|
}
|
||||||
p.finish(&mut ctx);
|
p.finish(&mut ctx);
|
||||||
(ctx, calls)
|
(ctx, calls)
|
||||||
|
|
@ -900,7 +957,7 @@ mod tests {
|
||||||
ctx.push(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
|
ctx.push(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
|
||||||
let mut p = ResponseParser::new(0);
|
let mut p = ResponseParser::new(0);
|
||||||
for ch in text.chars() {
|
for ch in text.chars() {
|
||||||
p.feed(&ch.to_string(), &mut ctx);
|
p.feed_token(&ch.to_string(), 0, &mut ctx);
|
||||||
}
|
}
|
||||||
p.finish(&mut ctx);
|
p.finish(&mut ctx);
|
||||||
let b = bodies(assistant_children(&ctx));
|
let b = bodies(assistant_children(&ctx));
|
||||||
|
|
@ -917,7 +974,7 @@ mod tests {
|
||||||
let mut p = ResponseParser::new(0);
|
let mut p = ResponseParser::new(0);
|
||||||
let mut tool_calls = 0;
|
let mut tool_calls = 0;
|
||||||
for ch in text.chars() {
|
for ch in text.chars() {
|
||||||
tool_calls += p.feed(&ch.to_string(), &mut ctx).len();
|
tool_calls += p.feed_token(&ch.to_string(), 0, &mut ctx).len();
|
||||||
}
|
}
|
||||||
p.finish(&mut ctx);
|
p.finish(&mut ctx);
|
||||||
assert_eq!(tool_calls, 1);
|
assert_eq!(tool_calls, 1);
|
||||||
|
|
|
||||||
106
src/agent/mod.rs
106
src/agent/mod.rs
|
|
@ -339,77 +339,55 @@ impl Agent {
|
||||||
AstNode::branch(Role::Assistant, vec![]));
|
AstNode::branch(Role::Assistant, vec![]));
|
||||||
idx
|
idx
|
||||||
};
|
};
|
||||||
let mut parser = ResponseParser::new(branch_idx);
|
|
||||||
let mut pending_calls: Vec<PendingToolCall> = Vec::new();
|
|
||||||
let mut had_content = false;
|
|
||||||
let mut stream_error: Option<String> = None;
|
|
||||||
|
|
||||||
// Stream loop — no lock held across I/O
|
let parser = ResponseParser::new(branch_idx);
|
||||||
while let Some(event) = rx.recv().await {
|
let (mut tool_rx, parser_handle) = parser.run(rx, agent.clone());
|
||||||
match event {
|
|
||||||
api::StreamToken::Token { text, id: _ } => {
|
let mut pending_calls: Vec<PendingToolCall> = Vec::new();
|
||||||
had_content = true;
|
while let Some(call) = tool_rx.recv().await {
|
||||||
let mut ctx = agent.context.lock().await;
|
let call_clone = call.clone();
|
||||||
let calls = parser.feed(&text, &mut ctx);
|
let agent_handle = agent.clone();
|
||||||
drop(ctx);
|
let handle = tokio::spawn(async move {
|
||||||
for call in calls {
|
let args: serde_json::Value =
|
||||||
let call_clone = call.clone();
|
serde_json::from_str(&call_clone.arguments).unwrap_or_default();
|
||||||
let agent_handle = agent.clone();
|
let output = tools::dispatch_with_agent(
|
||||||
let handle = tokio::spawn(async move {
|
&call_clone.name, &args, Some(agent_handle),
|
||||||
let args: serde_json::Value =
|
).await;
|
||||||
serde_json::from_str(&call_clone.arguments).unwrap_or_default();
|
(call_clone, output)
|
||||||
let output = tools::dispatch_with_agent(
|
});
|
||||||
&call_clone.name, &args, Some(agent_handle),
|
active_tools.lock().unwrap().push(tools::ActiveToolCall {
|
||||||
).await;
|
id: call.id.clone(),
|
||||||
(call_clone, output)
|
name: call.name.clone(),
|
||||||
});
|
detail: call.arguments.clone(),
|
||||||
active_tools.lock().unwrap().push(tools::ActiveToolCall {
|
started: std::time::Instant::now(),
|
||||||
id: call.id.clone(),
|
background: false,
|
||||||
name: call.name.clone(),
|
handle,
|
||||||
detail: call.arguments.clone(),
|
});
|
||||||
started: std::time::Instant::now(),
|
pending_calls.push(call);
|
||||||
background: false,
|
|
||||||
handle,
|
|
||||||
});
|
|
||||||
pending_calls.push(call);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
api::StreamToken::Error(e) => {
|
|
||||||
stream_error = Some(e);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
api::StreamToken::Done { usage } => {
|
|
||||||
if let Some(u) = usage {
|
|
||||||
agent.state.lock().await.last_prompt_tokens = u.prompt_tokens;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush parser remainder
|
// Check for stream/parse errors
|
||||||
parser.finish(&mut *agent.context.lock().await);
|
match parser_handle.await {
|
||||||
|
Ok(Err(e)) => {
|
||||||
// Handle errors
|
if context::is_context_overflow(&e) && overflow_retries < 2 {
|
||||||
if let Some(e) = stream_error {
|
overflow_retries += 1;
|
||||||
let err = anyhow::anyhow!("{}", e);
|
agent.state.lock().await.notify(
|
||||||
if context::is_context_overflow(&err) && overflow_retries < 2 {
|
format!("context overflow — retrying ({}/2)", overflow_retries));
|
||||||
overflow_retries += 1;
|
agent.compact().await;
|
||||||
agent.state.lock().await.notify(format!("context overflow — retrying ({}/2)", overflow_retries));
|
continue;
|
||||||
agent.compact().await;
|
}
|
||||||
continue;
|
return Err(e);
|
||||||
}
|
}
|
||||||
if context::is_stream_error(&err) && empty_retries < 2 {
|
Err(e) => return Err(anyhow::anyhow!("parser task panicked: {}", e)),
|
||||||
empty_retries += 1;
|
Ok(Ok(())) => {}
|
||||||
agent.state.lock().await.notify(format!("stream error — retrying ({}/2)", empty_retries));
|
|
||||||
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
return Err(err);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Empty response — nudge and retry
|
// Empty response — nudge and retry
|
||||||
if !had_content && pending_calls.is_empty() {
|
let has_content = {
|
||||||
|
let ctx = agent.context.lock().await;
|
||||||
|
!ctx.conversation()[branch_idx].children().is_empty()
|
||||||
|
};
|
||||||
|
if !has_content && pending_calls.is_empty() {
|
||||||
if empty_retries < 2 {
|
if empty_retries < 2 {
|
||||||
empty_retries += 1;
|
empty_retries += 1;
|
||||||
agent.push_node(AstNode::user_msg(
|
agent.push_node(AstNode::user_msg(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue