api: extract collect_stream() from agent turn loop
Move the entire stream event processing loop (content accumulation, leaked tool call detection/dispatch, ToolCallDelta assembly, UI forwarding, display buffering) into api::collect_stream(). The turn loop now calls collect_stream() and processes the StreamResult. Also move FunctionCall, ToolCall, ToolCallDelta to api/types.rs where they belong (API wire format, not tool definitions). Move parsing.rs to api/parsing.rs. Co-Authored-By: Proof of Concept <poc@bcachefs.org> Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
This commit is contained in:
parent
6845644f7b
commit
a14e85afe1
3 changed files with 155 additions and 118 deletions
|
|
@ -18,9 +18,9 @@ use std::time::{Duration, Instant};
|
|||
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::agent::tools::{self as agent_tools};
|
||||
use types::{ToolCall, FunctionCall};
|
||||
use crate::user::ui_channel::{UiMessage, UiSender};
|
||||
use crate::agent::tools::{self as agent_tools, summarize_args, ActiveToolCall};
|
||||
pub use types::ToolCall;
|
||||
use crate::user::ui_channel::{UiMessage, UiSender, StreamTarget};
|
||||
|
||||
/// A JoinHandle that aborts its task when dropped.
|
||||
pub struct AbortOnDrop(tokio::task::JoinHandle<()>);
|
||||
|
|
@ -594,3 +594,141 @@ pub(crate) fn log_diagnostics(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stream collection — assembles StreamEvents into a complete response
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Result of collecting a complete response from the stream.
|
||||
pub struct StreamResult {
|
||||
pub content: String,
|
||||
pub tool_calls: Vec<ToolCall>,
|
||||
pub usage: Option<Usage>,
|
||||
pub finish_reason: Option<String>,
|
||||
pub error: Option<String>,
|
||||
/// Remaining display buffer (caller should flush if not in a tool call).
|
||||
pub display_buf: String,
|
||||
/// Whether we were mid-tool-call when the stream ended.
|
||||
pub in_tool_call: bool,
|
||||
}
|
||||
|
||||
/// Collect stream events into a complete response. Handles:
|
||||
/// - Content accumulation and display buffering
|
||||
/// - Leaked tool call detection and dispatch (Qwen XML in content)
|
||||
/// - Structured tool call delta assembly (OpenAI-style)
|
||||
/// - UI forwarding (text deltas, reasoning, tool call notifications)
|
||||
pub async fn collect_stream(
|
||||
rx: &mut mpsc::UnboundedReceiver<StreamEvent>,
|
||||
ui_tx: &UiSender,
|
||||
target: StreamTarget,
|
||||
agent: &std::sync::Arc<tokio::sync::Mutex<super::Agent>>,
|
||||
active_tools: &crate::user::ui_channel::SharedActiveTools,
|
||||
) -> StreamResult {
|
||||
let mut content = String::new();
|
||||
let mut tool_calls: Vec<ToolCall> = Vec::new();
|
||||
let mut usage = None;
|
||||
let mut finish_reason = None;
|
||||
let mut in_tool_call = false;
|
||||
let mut tool_call_buf = String::new();
|
||||
let mut error = None;
|
||||
let mut first_content = true;
|
||||
let mut display_buf = String::new();
|
||||
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
StreamEvent::Content(text) => {
|
||||
if first_content {
|
||||
let _ = ui_tx.send(UiMessage::Activity("streaming...".into()));
|
||||
first_content = false;
|
||||
}
|
||||
content.push_str(&text);
|
||||
|
||||
if in_tool_call {
|
||||
tool_call_buf.push_str(&text);
|
||||
if let Some(end) = tool_call_buf.find("</tool_call>") {
|
||||
let body = &tool_call_buf[..end];
|
||||
if let Some(call) = parsing::parse_tool_call_body(body) {
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&call.function.arguments).unwrap_or_default();
|
||||
let args_summary = summarize_args(&call.function.name, &args);
|
||||
let _ = ui_tx.send(UiMessage::ToolCall {
|
||||
name: call.function.name.clone(),
|
||||
args_summary: args_summary.clone(),
|
||||
});
|
||||
let is_background = args.get("run_in_background")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
let call_id = call.id.clone();
|
||||
let call_name = call.function.name.clone();
|
||||
let agent_handle = agent.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
let output = agent_tools::dispatch_with_agent(
|
||||
&call.function.name, &args, Some(agent_handle)).await;
|
||||
(call, output)
|
||||
});
|
||||
active_tools.lock().unwrap().push(ActiveToolCall {
|
||||
id: call_id,
|
||||
name: call_name,
|
||||
detail: args_summary,
|
||||
started: std::time::Instant::now(),
|
||||
background: is_background,
|
||||
handle,
|
||||
});
|
||||
}
|
||||
let remaining = tool_call_buf[end + "</tool_call>".len()..].to_string();
|
||||
tool_call_buf.clear();
|
||||
in_tool_call = false;
|
||||
if !remaining.trim().is_empty() {
|
||||
display_buf.push_str(&remaining);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
display_buf.push_str(&text);
|
||||
if let Some(pos) = display_buf.find("<tool_call>") {
|
||||
let before = &display_buf[..pos];
|
||||
if !before.is_empty() {
|
||||
let _ = ui_tx.send(UiMessage::TextDelta(before.to_string(), target));
|
||||
}
|
||||
display_buf.clear();
|
||||
in_tool_call = true;
|
||||
} else {
|
||||
let safe = display_buf.len().saturating_sub(10);
|
||||
let safe = display_buf.floor_char_boundary(safe);
|
||||
if safe > 0 {
|
||||
let flush = display_buf[..safe].to_string();
|
||||
display_buf = display_buf[safe..].to_string();
|
||||
let _ = ui_tx.send(UiMessage::TextDelta(flush, target));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
StreamEvent::Reasoning(text) => {
|
||||
let _ = ui_tx.send(UiMessage::Reasoning(text));
|
||||
}
|
||||
StreamEvent::ToolCallDelta { index, id, call_type, name, arguments } => {
|
||||
while tool_calls.len() <= index {
|
||||
tool_calls.push(ToolCall {
|
||||
id: String::new(),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall { name: String::new(), arguments: String::new() },
|
||||
});
|
||||
}
|
||||
if let Some(id) = id { tool_calls[index].id = id; }
|
||||
if let Some(ct) = call_type { tool_calls[index].call_type = ct; }
|
||||
if let Some(n) = name { tool_calls[index].function.name = n; }
|
||||
if let Some(a) = arguments { tool_calls[index].function.arguments.push_str(&a); }
|
||||
}
|
||||
StreamEvent::Usage(u) => usage = Some(u),
|
||||
StreamEvent::Finished { reason, .. } => {
|
||||
finish_reason = Some(reason);
|
||||
break;
|
||||
}
|
||||
StreamEvent::Error(e) => {
|
||||
error = Some(e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StreamResult { content, tool_calls, usage, finish_reason, error, display_buf, in_tool_call }
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue