api: extract collect_stream() from agent turn loop

Move the entire stream event processing loop (content accumulation,
leaked tool call detection/dispatch, ToolCallDelta assembly, UI
forwarding, display buffering) into api::collect_stream(). The turn
loop now calls collect_stream() and processes the StreamResult.

Also move FunctionCall, ToolCall, ToolCallDelta to api/types.rs where
they belong (API wire format, not tool definitions). Move parsing.rs
to api/parsing.rs.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
This commit is contained in:
ProofOfConcept 2026-04-04 18:05:16 -04:00 committed by Kent Overstreet
parent 6845644f7b
commit a14e85afe1
3 changed files with 155 additions and 118 deletions

View file

@ -18,9 +18,9 @@ use std::time::{Duration, Instant};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use crate::agent::tools::{self as agent_tools}; use crate::agent::tools::{self as agent_tools, summarize_args, ActiveToolCall};
use types::{ToolCall, FunctionCall}; pub use types::ToolCall;
use crate::user::ui_channel::{UiMessage, UiSender}; use crate::user::ui_channel::{UiMessage, UiSender, StreamTarget};
/// A JoinHandle that aborts its task when dropped. /// A JoinHandle that aborts its task when dropped.
pub struct AbortOnDrop(tokio::task::JoinHandle<()>); pub struct AbortOnDrop(tokio::task::JoinHandle<()>);
@ -594,3 +594,141 @@ pub(crate) fn log_diagnostics(
} }
} }
} }
// ---------------------------------------------------------------------------
// Stream collection — assembles StreamEvents into a complete response
// ---------------------------------------------------------------------------
/// Result of collecting a complete response from the stream.
pub struct StreamResult {
pub content: String,
pub tool_calls: Vec<ToolCall>,
pub usage: Option<Usage>,
pub finish_reason: Option<String>,
pub error: Option<String>,
/// Remaining display buffer (caller should flush if not in a tool call).
pub display_buf: String,
/// Whether we were mid-tool-call when the stream ended.
pub in_tool_call: bool,
}
/// Collect stream events into a complete response. Handles:
/// - Content accumulation and display buffering
/// - Leaked tool call detection and dispatch (Qwen XML in content)
/// - Structured tool call delta assembly (OpenAI-style)
/// - UI forwarding (text deltas, reasoning, tool call notifications)
pub async fn collect_stream(
rx: &mut mpsc::UnboundedReceiver<StreamEvent>,
ui_tx: &UiSender,
target: StreamTarget,
agent: &std::sync::Arc<tokio::sync::Mutex<super::Agent>>,
active_tools: &crate::user::ui_channel::SharedActiveTools,
) -> StreamResult {
let mut content = String::new();
let mut tool_calls: Vec<ToolCall> = Vec::new();
let mut usage = None;
let mut finish_reason = None;
let mut in_tool_call = false;
let mut tool_call_buf = String::new();
let mut error = None;
let mut first_content = true;
let mut display_buf = String::new();
while let Some(event) = rx.recv().await {
match event {
StreamEvent::Content(text) => {
if first_content {
let _ = ui_tx.send(UiMessage::Activity("streaming...".into()));
first_content = false;
}
content.push_str(&text);
if in_tool_call {
tool_call_buf.push_str(&text);
if let Some(end) = tool_call_buf.find("</tool_call>") {
let body = &tool_call_buf[..end];
if let Some(call) = parsing::parse_tool_call_body(body) {
let args: serde_json::Value =
serde_json::from_str(&call.function.arguments).unwrap_or_default();
let args_summary = summarize_args(&call.function.name, &args);
let _ = ui_tx.send(UiMessage::ToolCall {
name: call.function.name.clone(),
args_summary: args_summary.clone(),
});
let is_background = args.get("run_in_background")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let call_id = call.id.clone();
let call_name = call.function.name.clone();
let agent_handle = agent.clone();
let handle = tokio::spawn(async move {
let output = agent_tools::dispatch_with_agent(
&call.function.name, &args, Some(agent_handle)).await;
(call, output)
});
active_tools.lock().unwrap().push(ActiveToolCall {
id: call_id,
name: call_name,
detail: args_summary,
started: std::time::Instant::now(),
background: is_background,
handle,
});
}
let remaining = tool_call_buf[end + "</tool_call>".len()..].to_string();
tool_call_buf.clear();
in_tool_call = false;
if !remaining.trim().is_empty() {
display_buf.push_str(&remaining);
}
}
} else {
display_buf.push_str(&text);
if let Some(pos) = display_buf.find("<tool_call>") {
let before = &display_buf[..pos];
if !before.is_empty() {
let _ = ui_tx.send(UiMessage::TextDelta(before.to_string(), target));
}
display_buf.clear();
in_tool_call = true;
} else {
let safe = display_buf.len().saturating_sub(10);
let safe = display_buf.floor_char_boundary(safe);
if safe > 0 {
let flush = display_buf[..safe].to_string();
display_buf = display_buf[safe..].to_string();
let _ = ui_tx.send(UiMessage::TextDelta(flush, target));
}
}
}
}
StreamEvent::Reasoning(text) => {
let _ = ui_tx.send(UiMessage::Reasoning(text));
}
StreamEvent::ToolCallDelta { index, id, call_type, name, arguments } => {
while tool_calls.len() <= index {
tool_calls.push(ToolCall {
id: String::new(),
call_type: "function".to_string(),
function: FunctionCall { name: String::new(), arguments: String::new() },
});
}
if let Some(id) = id { tool_calls[index].id = id; }
if let Some(ct) = call_type { tool_calls[index].call_type = ct; }
if let Some(n) = name { tool_calls[index].function.name = n; }
if let Some(a) = arguments { tool_calls[index].function.arguments.push_str(&a); }
}
StreamEvent::Usage(u) => usage = Some(u),
StreamEvent::Finished { reason, .. } => {
finish_reason = Some(reason);
break;
}
StreamEvent::Error(e) => {
error = Some(e);
break;
}
}
}
StreamResult { content, tool_calls, usage, finish_reason, error, display_buf, in_tool_call }
}

View file

@ -23,14 +23,12 @@ use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use tiktoken_rs::CoreBPE; use tiktoken_rs::CoreBPE;
use api::{ApiClient, StreamEvent}; use api::{ApiClient, ToolCall};
use context as journal; use api::types::{ContentPart, Message, MessageContent, Role};
use tools::{ToolCall, FunctionCall, summarize_args}; use context::{ConversationEntry, ContextState, ContextBudget};
use tools::{summarize_args, working_stack};
use crate::user::log::ConversationLog; use crate::user::log::ConversationLog;
use crate::agent::api::types::*;
use crate::agent::context::{ConversationEntry, ContextState, ContextBudget};
use crate::agent::tools::working_stack;
use crate::user::ui_channel::{ContextSection, SharedContextState, StreamTarget, StatusInfo, UiMessage, UiSender}; use crate::user::ui_channel::{ContextSection, SharedContextState, StreamTarget, StatusInfo, UiMessage, UiSender};
/// Result of a single agent turn. /// Result of a single agent turn.
@ -104,7 +102,7 @@ pub struct Agent {
pub active_tools: crate::user::ui_channel::SharedActiveTools, pub active_tools: crate::user::ui_channel::SharedActiveTools,
} }
fn render_journal(entries: &[journal::JournalEntry]) -> String { fn render_journal(entries: &[context::JournalEntry]) -> String {
if entries.is_empty() { return String::new(); } if entries.is_empty() { return String::new(); }
let mut text = String::from("[Earlier — from your journal]\n\n"); let mut text = String::from("[Earlier — from your journal]\n\n");
for entry in entries { for entry in entries {
@ -316,112 +314,13 @@ impl Agent {
// --- Lock released --- // --- Lock released ---
// --- Stream loop (no lock) --- // --- Stream loop (no lock) ---
let mut content = String::new(); let sr = api::collect_stream(
let mut tool_calls: Vec<ToolCall> = Vec::new(); &mut rx, ui_tx, target, &agent, &active_tools,
let mut usage = None; ).await;
let mut finish_reason = None; let api::StreamResult {
let mut in_tool_call = false; content, tool_calls, usage, finish_reason,
let mut tool_call_buf = String::new(); error: stream_error, display_buf, in_tool_call,
let mut stream_error = None; } = sr;
let mut first_content = true;
let mut display_buf = String::new();
while let Some(event) = rx.recv().await {
match event {
StreamEvent::Content(text) => {
if first_content {
let _ = ui_tx.send(UiMessage::Activity("streaming...".into()));
first_content = false;
}
content.push_str(&text);
if in_tool_call {
tool_call_buf.push_str(&text);
if let Some(end) = tool_call_buf.find("</tool_call>") {
let body = &tool_call_buf[..end];
if let Some(call) = crate::agent::api::parsing::parse_tool_call_body(body) {
let args: serde_json::Value =
serde_json::from_str(&call.function.arguments).unwrap_or_default();
let args_summary = summarize_args(&call.function.name, &args);
let _ = ui_tx.send(UiMessage::ToolCall {
name: call.function.name.clone(),
args_summary: args_summary.clone(),
});
let is_background = args.get("run_in_background")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let call_id = call.id.clone();
let call_name = call.function.name.clone();
let agent_handle = agent.clone();
let handle = tokio::spawn(async move {
let output = tools::dispatch_with_agent(&call.function.name, &args, Some(agent_handle)).await;
(call, output)
});
active_tools.lock().unwrap().push(
tools::ActiveToolCall {
id: call_id,
name: call_name,
detail: args_summary,
started: std::time::Instant::now(),
background: is_background,
handle,
}
);
}
let remaining = tool_call_buf[end + "</tool_call>".len()..].to_string();
tool_call_buf.clear();
in_tool_call = false;
if !remaining.trim().is_empty() {
display_buf.push_str(&remaining);
}
}
} else {
display_buf.push_str(&text);
if let Some(pos) = display_buf.find("<tool_call>") {
let before = &display_buf[..pos];
if !before.is_empty() {
let _ = ui_tx.send(UiMessage::TextDelta(before.to_string(), target));
}
display_buf.clear();
in_tool_call = true;
} else {
let safe = display_buf.len().saturating_sub(10);
let safe = display_buf.floor_char_boundary(safe);
if safe > 0 {
let flush = display_buf[..safe].to_string();
display_buf = display_buf[safe..].to_string();
let _ = ui_tx.send(UiMessage::TextDelta(flush, target));
}
}
}
}
StreamEvent::Reasoning(text) => {
let _ = ui_tx.send(UiMessage::Reasoning(text));
}
StreamEvent::ToolCallDelta { index, id, call_type, name, arguments } => {
while tool_calls.len() <= index {
tool_calls.push(ToolCall {
id: String::new(),
call_type: "function".to_string(),
function: FunctionCall { name: String::new(), arguments: String::new() },
});
}
if let Some(id) = id { tool_calls[index].id = id; }
if let Some(ct) = call_type { tool_calls[index].call_type = ct; }
if let Some(n) = name { tool_calls[index].function.name = n; }
if let Some(a) = arguments { tool_calls[index].function.arguments.push_str(&a); }
}
StreamEvent::Usage(u) => usage = Some(u),
StreamEvent::Finished { reason, .. } => {
finish_reason = Some(reason);
break;
}
StreamEvent::Error(e) => {
stream_error = Some(e);
break;
}
}
}
// --- Stream complete --- // --- Stream complete ---
// --- Lock 3: process results --- // --- Lock 3: process results ---
@ -918,7 +817,7 @@ impl Agent {
if total_tokens + tokens > journal_budget && !entries.is_empty() { if total_tokens + tokens > journal_budget && !entries.is_empty() {
break; break;
} }
entries.push(journal::JournalEntry { entries.push(context::JournalEntry {
timestamp: chrono::DateTime::from_timestamp(node.created_at, 0) timestamp: chrono::DateTime::from_timestamp(node.created_at, 0)
.unwrap_or_default(), .unwrap_or_default(),
content: node.content.clone(), content: node.content.clone(),

View file

@ -57,7 +57,7 @@ impl Tool {
} }
// Re-export API wire types used by the agent turn loop // Re-export API wire types used by the agent turn loop
pub use super::api::types::{FunctionCall, ToolCall, ToolCallDelta}; use super::api::types::ToolCall;
/// A tool call in flight — metadata for TUI + JoinHandle for /// A tool call in flight — metadata for TUI + JoinHandle for
/// result collection and cancellation. /// result collection and cancellation.