From a14e85afe14b9a5d61c2abd869fce84bc43fd876 Mon Sep 17 00:00:00 2001 From: ProofOfConcept Date: Sat, 4 Apr 2026 18:05:16 -0400 Subject: [PATCH] api: extract collect_stream() from agent turn loop Move the entire stream event processing loop (content accumulation, leaked tool call detection/dispatch, ToolCallDelta assembly, UI forwarding, display buffering) into api::collect_stream(). The turn loop now calls collect_stream() and processes the StreamResult. Also move FunctionCall, ToolCall, ToolCallDelta to api/types.rs where they belong (API wire format, not tool definitions). Move parsing.rs to api/parsing.rs. Co-Authored-By: Proof of Concept Signed-off-by: Kent Overstreet --- src/agent/api/mod.rs | 144 ++++++++++++++++++++++++++++++++++++++++- src/agent/mod.rs | 127 ++++-------------------------------- src/agent/tools/mod.rs | 2 +- 3 files changed, 155 insertions(+), 118 deletions(-) diff --git a/src/agent/api/mod.rs b/src/agent/api/mod.rs index 169f8e4..bde1704 100644 --- a/src/agent/api/mod.rs +++ b/src/agent/api/mod.rs @@ -18,9 +18,9 @@ use std::time::{Duration, Instant}; use tokio::sync::mpsc; -use crate::agent::tools::{self as agent_tools}; -use types::{ToolCall, FunctionCall}; -use crate::user::ui_channel::{UiMessage, UiSender}; +use crate::agent::tools::{self as agent_tools, summarize_args, ActiveToolCall}; +pub use types::ToolCall; +use crate::user::ui_channel::{UiMessage, UiSender, StreamTarget}; /// A JoinHandle that aborts its task when dropped. pub struct AbortOnDrop(tokio::task::JoinHandle<()>); @@ -594,3 +594,141 @@ pub(crate) fn log_diagnostics( } } } + +// --------------------------------------------------------------------------- +// Stream collection — assembles StreamEvents into a complete response +// --------------------------------------------------------------------------- + +/// Result of collecting a complete response from the stream. +pub struct StreamResult { + pub content: String, + pub tool_calls: Vec, + pub usage: Option, + pub finish_reason: Option, + pub error: Option, + /// Remaining display buffer (caller should flush if not in a tool call). + pub display_buf: String, + /// Whether we were mid-tool-call when the stream ended. + pub in_tool_call: bool, +} + +/// Collect stream events into a complete response. Handles: +/// - Content accumulation and display buffering +/// - Leaked tool call detection and dispatch (Qwen XML in content) +/// - Structured tool call delta assembly (OpenAI-style) +/// - UI forwarding (text deltas, reasoning, tool call notifications) +pub async fn collect_stream( + rx: &mut mpsc::UnboundedReceiver, + ui_tx: &UiSender, + target: StreamTarget, + agent: &std::sync::Arc>, + active_tools: &crate::user::ui_channel::SharedActiveTools, +) -> StreamResult { + let mut content = String::new(); + let mut tool_calls: Vec = Vec::new(); + let mut usage = None; + let mut finish_reason = None; + let mut in_tool_call = false; + let mut tool_call_buf = String::new(); + let mut error = None; + let mut first_content = true; + let mut display_buf = String::new(); + + while let Some(event) = rx.recv().await { + match event { + StreamEvent::Content(text) => { + if first_content { + let _ = ui_tx.send(UiMessage::Activity("streaming...".into())); + first_content = false; + } + content.push_str(&text); + + if in_tool_call { + tool_call_buf.push_str(&text); + if let Some(end) = tool_call_buf.find("") { + let body = &tool_call_buf[..end]; + if let Some(call) = parsing::parse_tool_call_body(body) { + let args: serde_json::Value = + serde_json::from_str(&call.function.arguments).unwrap_or_default(); + let args_summary = summarize_args(&call.function.name, &args); + let _ = ui_tx.send(UiMessage::ToolCall { + name: call.function.name.clone(), + args_summary: args_summary.clone(), + }); + let is_background = args.get("run_in_background") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + let call_id = call.id.clone(); + let call_name = call.function.name.clone(); + let agent_handle = agent.clone(); + let handle = tokio::spawn(async move { + let output = agent_tools::dispatch_with_agent( + &call.function.name, &args, Some(agent_handle)).await; + (call, output) + }); + active_tools.lock().unwrap().push(ActiveToolCall { + id: call_id, + name: call_name, + detail: args_summary, + started: std::time::Instant::now(), + background: is_background, + handle, + }); + } + let remaining = tool_call_buf[end + "".len()..].to_string(); + tool_call_buf.clear(); + in_tool_call = false; + if !remaining.trim().is_empty() { + display_buf.push_str(&remaining); + } + } + } else { + display_buf.push_str(&text); + if let Some(pos) = display_buf.find("") { + let before = &display_buf[..pos]; + if !before.is_empty() { + let _ = ui_tx.send(UiMessage::TextDelta(before.to_string(), target)); + } + display_buf.clear(); + in_tool_call = true; + } else { + let safe = display_buf.len().saturating_sub(10); + let safe = display_buf.floor_char_boundary(safe); + if safe > 0 { + let flush = display_buf[..safe].to_string(); + display_buf = display_buf[safe..].to_string(); + let _ = ui_tx.send(UiMessage::TextDelta(flush, target)); + } + } + } + } + StreamEvent::Reasoning(text) => { + let _ = ui_tx.send(UiMessage::Reasoning(text)); + } + StreamEvent::ToolCallDelta { index, id, call_type, name, arguments } => { + while tool_calls.len() <= index { + tool_calls.push(ToolCall { + id: String::new(), + call_type: "function".to_string(), + function: FunctionCall { name: String::new(), arguments: String::new() }, + }); + } + if let Some(id) = id { tool_calls[index].id = id; } + if let Some(ct) = call_type { tool_calls[index].call_type = ct; } + if let Some(n) = name { tool_calls[index].function.name = n; } + if let Some(a) = arguments { tool_calls[index].function.arguments.push_str(&a); } + } + StreamEvent::Usage(u) => usage = Some(u), + StreamEvent::Finished { reason, .. } => { + finish_reason = Some(reason); + break; + } + StreamEvent::Error(e) => { + error = Some(e); + break; + } + } + } + + StreamResult { content, tool_calls, usage, finish_reason, error, display_buf, in_tool_call } +} diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 33a2738..554af06 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -23,14 +23,12 @@ use std::sync::Arc; use anyhow::Result; use tiktoken_rs::CoreBPE; -use api::{ApiClient, StreamEvent}; -use context as journal; -use tools::{ToolCall, FunctionCall, summarize_args}; +use api::{ApiClient, ToolCall}; +use api::types::{ContentPart, Message, MessageContent, Role}; +use context::{ConversationEntry, ContextState, ContextBudget}; +use tools::{summarize_args, working_stack}; use crate::user::log::ConversationLog; -use crate::agent::api::types::*; -use crate::agent::context::{ConversationEntry, ContextState, ContextBudget}; -use crate::agent::tools::working_stack; use crate::user::ui_channel::{ContextSection, SharedContextState, StreamTarget, StatusInfo, UiMessage, UiSender}; /// Result of a single agent turn. @@ -104,7 +102,7 @@ pub struct Agent { pub active_tools: crate::user::ui_channel::SharedActiveTools, } -fn render_journal(entries: &[journal::JournalEntry]) -> String { +fn render_journal(entries: &[context::JournalEntry]) -> String { if entries.is_empty() { return String::new(); } let mut text = String::from("[Earlier — from your journal]\n\n"); for entry in entries { @@ -316,112 +314,13 @@ impl Agent { // --- Lock released --- // --- Stream loop (no lock) --- - let mut content = String::new(); - let mut tool_calls: Vec = Vec::new(); - let mut usage = None; - let mut finish_reason = None; - let mut in_tool_call = false; - let mut tool_call_buf = String::new(); - let mut stream_error = None; - let mut first_content = true; - let mut display_buf = String::new(); - - while let Some(event) = rx.recv().await { - match event { - StreamEvent::Content(text) => { - if first_content { - let _ = ui_tx.send(UiMessage::Activity("streaming...".into())); - first_content = false; - } - content.push_str(&text); - - if in_tool_call { - tool_call_buf.push_str(&text); - if let Some(end) = tool_call_buf.find("") { - let body = &tool_call_buf[..end]; - if let Some(call) = crate::agent::api::parsing::parse_tool_call_body(body) { - let args: serde_json::Value = - serde_json::from_str(&call.function.arguments).unwrap_or_default(); - let args_summary = summarize_args(&call.function.name, &args); - let _ = ui_tx.send(UiMessage::ToolCall { - name: call.function.name.clone(), - args_summary: args_summary.clone(), - }); - let is_background = args.get("run_in_background") - .and_then(|v| v.as_bool()) - .unwrap_or(false); - let call_id = call.id.clone(); - let call_name = call.function.name.clone(); - let agent_handle = agent.clone(); - let handle = tokio::spawn(async move { - let output = tools::dispatch_with_agent(&call.function.name, &args, Some(agent_handle)).await; - (call, output) - }); - active_tools.lock().unwrap().push( - tools::ActiveToolCall { - id: call_id, - name: call_name, - detail: args_summary, - started: std::time::Instant::now(), - background: is_background, - handle, - } - ); - } - let remaining = tool_call_buf[end + "".len()..].to_string(); - tool_call_buf.clear(); - in_tool_call = false; - if !remaining.trim().is_empty() { - display_buf.push_str(&remaining); - } - } - } else { - display_buf.push_str(&text); - if let Some(pos) = display_buf.find("") { - let before = &display_buf[..pos]; - if !before.is_empty() { - let _ = ui_tx.send(UiMessage::TextDelta(before.to_string(), target)); - } - display_buf.clear(); - in_tool_call = true; - } else { - let safe = display_buf.len().saturating_sub(10); - let safe = display_buf.floor_char_boundary(safe); - if safe > 0 { - let flush = display_buf[..safe].to_string(); - display_buf = display_buf[safe..].to_string(); - let _ = ui_tx.send(UiMessage::TextDelta(flush, target)); - } - } - } - } - StreamEvent::Reasoning(text) => { - let _ = ui_tx.send(UiMessage::Reasoning(text)); - } - StreamEvent::ToolCallDelta { index, id, call_type, name, arguments } => { - while tool_calls.len() <= index { - tool_calls.push(ToolCall { - id: String::new(), - call_type: "function".to_string(), - function: FunctionCall { name: String::new(), arguments: String::new() }, - }); - } - if let Some(id) = id { tool_calls[index].id = id; } - if let Some(ct) = call_type { tool_calls[index].call_type = ct; } - if let Some(n) = name { tool_calls[index].function.name = n; } - if let Some(a) = arguments { tool_calls[index].function.arguments.push_str(&a); } - } - StreamEvent::Usage(u) => usage = Some(u), - StreamEvent::Finished { reason, .. } => { - finish_reason = Some(reason); - break; - } - StreamEvent::Error(e) => { - stream_error = Some(e); - break; - } - } - } + let sr = api::collect_stream( + &mut rx, ui_tx, target, &agent, &active_tools, + ).await; + let api::StreamResult { + content, tool_calls, usage, finish_reason, + error: stream_error, display_buf, in_tool_call, + } = sr; // --- Stream complete --- // --- Lock 3: process results --- @@ -918,7 +817,7 @@ impl Agent { if total_tokens + tokens > journal_budget && !entries.is_empty() { break; } - entries.push(journal::JournalEntry { + entries.push(context::JournalEntry { timestamp: chrono::DateTime::from_timestamp(node.created_at, 0) .unwrap_or_default(), content: node.content.clone(), diff --git a/src/agent/tools/mod.rs b/src/agent/tools/mod.rs index 041ab45..95ef89f 100644 --- a/src/agent/tools/mod.rs +++ b/src/agent/tools/mod.rs @@ -57,7 +57,7 @@ impl Tool { } // Re-export API wire types used by the agent turn loop -pub use super::api::types::{FunctionCall, ToolCall, ToolCallDelta}; +use super::api::types::ToolCall; /// A tool call in flight — metadata for TUI + JoinHandle for /// result collection and cancellation.