api: extract collect_stream() from agent turn loop
Move the entire stream event processing loop (content accumulation, leaked tool call detection/dispatch, ToolCallDelta assembly, UI forwarding, display buffering) into api::collect_stream(). The turn loop now calls collect_stream() and processes the StreamResult. Also move FunctionCall, ToolCall, ToolCallDelta to api/types.rs where they belong (API wire format, not tool definitions). Move parsing.rs to api/parsing.rs. Co-Authored-By: Proof of Concept <poc@bcachefs.org> Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
This commit is contained in:
parent
6845644f7b
commit
a14e85afe1
3 changed files with 155 additions and 118 deletions
|
|
@ -18,9 +18,9 @@ use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
use crate::agent::tools::{self as agent_tools};
|
use crate::agent::tools::{self as agent_tools, summarize_args, ActiveToolCall};
|
||||||
use types::{ToolCall, FunctionCall};
|
pub use types::ToolCall;
|
||||||
use crate::user::ui_channel::{UiMessage, UiSender};
|
use crate::user::ui_channel::{UiMessage, UiSender, StreamTarget};
|
||||||
|
|
||||||
/// A JoinHandle that aborts its task when dropped.
|
/// A JoinHandle that aborts its task when dropped.
|
||||||
pub struct AbortOnDrop(tokio::task::JoinHandle<()>);
|
pub struct AbortOnDrop(tokio::task::JoinHandle<()>);
|
||||||
|
|
@ -594,3 +594,141 @@ pub(crate) fn log_diagnostics(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Stream collection — assembles StreamEvents into a complete response
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Result of collecting a complete response from the stream.
|
||||||
|
pub struct StreamResult {
|
||||||
|
pub content: String,
|
||||||
|
pub tool_calls: Vec<ToolCall>,
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
pub error: Option<String>,
|
||||||
|
/// Remaining display buffer (caller should flush if not in a tool call).
|
||||||
|
pub display_buf: String,
|
||||||
|
/// Whether we were mid-tool-call when the stream ended.
|
||||||
|
pub in_tool_call: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Collect stream events into a complete response. Handles:
|
||||||
|
/// - Content accumulation and display buffering
|
||||||
|
/// - Leaked tool call detection and dispatch (Qwen XML in content)
|
||||||
|
/// - Structured tool call delta assembly (OpenAI-style)
|
||||||
|
/// - UI forwarding (text deltas, reasoning, tool call notifications)
|
||||||
|
pub async fn collect_stream(
|
||||||
|
rx: &mut mpsc::UnboundedReceiver<StreamEvent>,
|
||||||
|
ui_tx: &UiSender,
|
||||||
|
target: StreamTarget,
|
||||||
|
agent: &std::sync::Arc<tokio::sync::Mutex<super::Agent>>,
|
||||||
|
active_tools: &crate::user::ui_channel::SharedActiveTools,
|
||||||
|
) -> StreamResult {
|
||||||
|
let mut content = String::new();
|
||||||
|
let mut tool_calls: Vec<ToolCall> = Vec::new();
|
||||||
|
let mut usage = None;
|
||||||
|
let mut finish_reason = None;
|
||||||
|
let mut in_tool_call = false;
|
||||||
|
let mut tool_call_buf = String::new();
|
||||||
|
let mut error = None;
|
||||||
|
let mut first_content = true;
|
||||||
|
let mut display_buf = String::new();
|
||||||
|
|
||||||
|
while let Some(event) = rx.recv().await {
|
||||||
|
match event {
|
||||||
|
StreamEvent::Content(text) => {
|
||||||
|
if first_content {
|
||||||
|
let _ = ui_tx.send(UiMessage::Activity("streaming...".into()));
|
||||||
|
first_content = false;
|
||||||
|
}
|
||||||
|
content.push_str(&text);
|
||||||
|
|
||||||
|
if in_tool_call {
|
||||||
|
tool_call_buf.push_str(&text);
|
||||||
|
if let Some(end) = tool_call_buf.find("</tool_call>") {
|
||||||
|
let body = &tool_call_buf[..end];
|
||||||
|
if let Some(call) = parsing::parse_tool_call_body(body) {
|
||||||
|
let args: serde_json::Value =
|
||||||
|
serde_json::from_str(&call.function.arguments).unwrap_or_default();
|
||||||
|
let args_summary = summarize_args(&call.function.name, &args);
|
||||||
|
let _ = ui_tx.send(UiMessage::ToolCall {
|
||||||
|
name: call.function.name.clone(),
|
||||||
|
args_summary: args_summary.clone(),
|
||||||
|
});
|
||||||
|
let is_background = args.get("run_in_background")
|
||||||
|
.and_then(|v| v.as_bool())
|
||||||
|
.unwrap_or(false);
|
||||||
|
let call_id = call.id.clone();
|
||||||
|
let call_name = call.function.name.clone();
|
||||||
|
let agent_handle = agent.clone();
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
let output = agent_tools::dispatch_with_agent(
|
||||||
|
&call.function.name, &args, Some(agent_handle)).await;
|
||||||
|
(call, output)
|
||||||
|
});
|
||||||
|
active_tools.lock().unwrap().push(ActiveToolCall {
|
||||||
|
id: call_id,
|
||||||
|
name: call_name,
|
||||||
|
detail: args_summary,
|
||||||
|
started: std::time::Instant::now(),
|
||||||
|
background: is_background,
|
||||||
|
handle,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let remaining = tool_call_buf[end + "</tool_call>".len()..].to_string();
|
||||||
|
tool_call_buf.clear();
|
||||||
|
in_tool_call = false;
|
||||||
|
if !remaining.trim().is_empty() {
|
||||||
|
display_buf.push_str(&remaining);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
display_buf.push_str(&text);
|
||||||
|
if let Some(pos) = display_buf.find("<tool_call>") {
|
||||||
|
let before = &display_buf[..pos];
|
||||||
|
if !before.is_empty() {
|
||||||
|
let _ = ui_tx.send(UiMessage::TextDelta(before.to_string(), target));
|
||||||
|
}
|
||||||
|
display_buf.clear();
|
||||||
|
in_tool_call = true;
|
||||||
|
} else {
|
||||||
|
let safe = display_buf.len().saturating_sub(10);
|
||||||
|
let safe = display_buf.floor_char_boundary(safe);
|
||||||
|
if safe > 0 {
|
||||||
|
let flush = display_buf[..safe].to_string();
|
||||||
|
display_buf = display_buf[safe..].to_string();
|
||||||
|
let _ = ui_tx.send(UiMessage::TextDelta(flush, target));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
StreamEvent::Reasoning(text) => {
|
||||||
|
let _ = ui_tx.send(UiMessage::Reasoning(text));
|
||||||
|
}
|
||||||
|
StreamEvent::ToolCallDelta { index, id, call_type, name, arguments } => {
|
||||||
|
while tool_calls.len() <= index {
|
||||||
|
tool_calls.push(ToolCall {
|
||||||
|
id: String::new(),
|
||||||
|
call_type: "function".to_string(),
|
||||||
|
function: FunctionCall { name: String::new(), arguments: String::new() },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if let Some(id) = id { tool_calls[index].id = id; }
|
||||||
|
if let Some(ct) = call_type { tool_calls[index].call_type = ct; }
|
||||||
|
if let Some(n) = name { tool_calls[index].function.name = n; }
|
||||||
|
if let Some(a) = arguments { tool_calls[index].function.arguments.push_str(&a); }
|
||||||
|
}
|
||||||
|
StreamEvent::Usage(u) => usage = Some(u),
|
||||||
|
StreamEvent::Finished { reason, .. } => {
|
||||||
|
finish_reason = Some(reason);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
StreamEvent::Error(e) => {
|
||||||
|
error = Some(e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
StreamResult { content, tool_calls, usage, finish_reason, error, display_buf, in_tool_call }
|
||||||
|
}
|
||||||
|
|
|
||||||
127
src/agent/mod.rs
127
src/agent/mod.rs
|
|
@ -23,14 +23,12 @@ use std::sync::Arc;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use tiktoken_rs::CoreBPE;
|
use tiktoken_rs::CoreBPE;
|
||||||
|
|
||||||
use api::{ApiClient, StreamEvent};
|
use api::{ApiClient, ToolCall};
|
||||||
use context as journal;
|
use api::types::{ContentPart, Message, MessageContent, Role};
|
||||||
use tools::{ToolCall, FunctionCall, summarize_args};
|
use context::{ConversationEntry, ContextState, ContextBudget};
|
||||||
|
use tools::{summarize_args, working_stack};
|
||||||
|
|
||||||
use crate::user::log::ConversationLog;
|
use crate::user::log::ConversationLog;
|
||||||
use crate::agent::api::types::*;
|
|
||||||
use crate::agent::context::{ConversationEntry, ContextState, ContextBudget};
|
|
||||||
use crate::agent::tools::working_stack;
|
|
||||||
use crate::user::ui_channel::{ContextSection, SharedContextState, StreamTarget, StatusInfo, UiMessage, UiSender};
|
use crate::user::ui_channel::{ContextSection, SharedContextState, StreamTarget, StatusInfo, UiMessage, UiSender};
|
||||||
|
|
||||||
/// Result of a single agent turn.
|
/// Result of a single agent turn.
|
||||||
|
|
@ -104,7 +102,7 @@ pub struct Agent {
|
||||||
pub active_tools: crate::user::ui_channel::SharedActiveTools,
|
pub active_tools: crate::user::ui_channel::SharedActiveTools,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_journal(entries: &[journal::JournalEntry]) -> String {
|
fn render_journal(entries: &[context::JournalEntry]) -> String {
|
||||||
if entries.is_empty() { return String::new(); }
|
if entries.is_empty() { return String::new(); }
|
||||||
let mut text = String::from("[Earlier — from your journal]\n\n");
|
let mut text = String::from("[Earlier — from your journal]\n\n");
|
||||||
for entry in entries {
|
for entry in entries {
|
||||||
|
|
@ -316,112 +314,13 @@ impl Agent {
|
||||||
// --- Lock released ---
|
// --- Lock released ---
|
||||||
|
|
||||||
// --- Stream loop (no lock) ---
|
// --- Stream loop (no lock) ---
|
||||||
let mut content = String::new();
|
let sr = api::collect_stream(
|
||||||
let mut tool_calls: Vec<ToolCall> = Vec::new();
|
&mut rx, ui_tx, target, &agent, &active_tools,
|
||||||
let mut usage = None;
|
).await;
|
||||||
let mut finish_reason = None;
|
let api::StreamResult {
|
||||||
let mut in_tool_call = false;
|
content, tool_calls, usage, finish_reason,
|
||||||
let mut tool_call_buf = String::new();
|
error: stream_error, display_buf, in_tool_call,
|
||||||
let mut stream_error = None;
|
} = sr;
|
||||||
let mut first_content = true;
|
|
||||||
let mut display_buf = String::new();
|
|
||||||
|
|
||||||
while let Some(event) = rx.recv().await {
|
|
||||||
match event {
|
|
||||||
StreamEvent::Content(text) => {
|
|
||||||
if first_content {
|
|
||||||
let _ = ui_tx.send(UiMessage::Activity("streaming...".into()));
|
|
||||||
first_content = false;
|
|
||||||
}
|
|
||||||
content.push_str(&text);
|
|
||||||
|
|
||||||
if in_tool_call {
|
|
||||||
tool_call_buf.push_str(&text);
|
|
||||||
if let Some(end) = tool_call_buf.find("</tool_call>") {
|
|
||||||
let body = &tool_call_buf[..end];
|
|
||||||
if let Some(call) = crate::agent::api::parsing::parse_tool_call_body(body) {
|
|
||||||
let args: serde_json::Value =
|
|
||||||
serde_json::from_str(&call.function.arguments).unwrap_or_default();
|
|
||||||
let args_summary = summarize_args(&call.function.name, &args);
|
|
||||||
let _ = ui_tx.send(UiMessage::ToolCall {
|
|
||||||
name: call.function.name.clone(),
|
|
||||||
args_summary: args_summary.clone(),
|
|
||||||
});
|
|
||||||
let is_background = args.get("run_in_background")
|
|
||||||
.and_then(|v| v.as_bool())
|
|
||||||
.unwrap_or(false);
|
|
||||||
let call_id = call.id.clone();
|
|
||||||
let call_name = call.function.name.clone();
|
|
||||||
let agent_handle = agent.clone();
|
|
||||||
let handle = tokio::spawn(async move {
|
|
||||||
let output = tools::dispatch_with_agent(&call.function.name, &args, Some(agent_handle)).await;
|
|
||||||
(call, output)
|
|
||||||
});
|
|
||||||
active_tools.lock().unwrap().push(
|
|
||||||
tools::ActiveToolCall {
|
|
||||||
id: call_id,
|
|
||||||
name: call_name,
|
|
||||||
detail: args_summary,
|
|
||||||
started: std::time::Instant::now(),
|
|
||||||
background: is_background,
|
|
||||||
handle,
|
|
||||||
}
|
|
||||||
);
|
|
||||||
}
|
|
||||||
let remaining = tool_call_buf[end + "</tool_call>".len()..].to_string();
|
|
||||||
tool_call_buf.clear();
|
|
||||||
in_tool_call = false;
|
|
||||||
if !remaining.trim().is_empty() {
|
|
||||||
display_buf.push_str(&remaining);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
display_buf.push_str(&text);
|
|
||||||
if let Some(pos) = display_buf.find("<tool_call>") {
|
|
||||||
let before = &display_buf[..pos];
|
|
||||||
if !before.is_empty() {
|
|
||||||
let _ = ui_tx.send(UiMessage::TextDelta(before.to_string(), target));
|
|
||||||
}
|
|
||||||
display_buf.clear();
|
|
||||||
in_tool_call = true;
|
|
||||||
} else {
|
|
||||||
let safe = display_buf.len().saturating_sub(10);
|
|
||||||
let safe = display_buf.floor_char_boundary(safe);
|
|
||||||
if safe > 0 {
|
|
||||||
let flush = display_buf[..safe].to_string();
|
|
||||||
display_buf = display_buf[safe..].to_string();
|
|
||||||
let _ = ui_tx.send(UiMessage::TextDelta(flush, target));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
StreamEvent::Reasoning(text) => {
|
|
||||||
let _ = ui_tx.send(UiMessage::Reasoning(text));
|
|
||||||
}
|
|
||||||
StreamEvent::ToolCallDelta { index, id, call_type, name, arguments } => {
|
|
||||||
while tool_calls.len() <= index {
|
|
||||||
tool_calls.push(ToolCall {
|
|
||||||
id: String::new(),
|
|
||||||
call_type: "function".to_string(),
|
|
||||||
function: FunctionCall { name: String::new(), arguments: String::new() },
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if let Some(id) = id { tool_calls[index].id = id; }
|
|
||||||
if let Some(ct) = call_type { tool_calls[index].call_type = ct; }
|
|
||||||
if let Some(n) = name { tool_calls[index].function.name = n; }
|
|
||||||
if let Some(a) = arguments { tool_calls[index].function.arguments.push_str(&a); }
|
|
||||||
}
|
|
||||||
StreamEvent::Usage(u) => usage = Some(u),
|
|
||||||
StreamEvent::Finished { reason, .. } => {
|
|
||||||
finish_reason = Some(reason);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
StreamEvent::Error(e) => {
|
|
||||||
stream_error = Some(e);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// --- Stream complete ---
|
// --- Stream complete ---
|
||||||
|
|
||||||
// --- Lock 3: process results ---
|
// --- Lock 3: process results ---
|
||||||
|
|
@ -918,7 +817,7 @@ impl Agent {
|
||||||
if total_tokens + tokens > journal_budget && !entries.is_empty() {
|
if total_tokens + tokens > journal_budget && !entries.is_empty() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
entries.push(journal::JournalEntry {
|
entries.push(context::JournalEntry {
|
||||||
timestamp: chrono::DateTime::from_timestamp(node.created_at, 0)
|
timestamp: chrono::DateTime::from_timestamp(node.created_at, 0)
|
||||||
.unwrap_or_default(),
|
.unwrap_or_default(),
|
||||||
content: node.content.clone(),
|
content: node.content.clone(),
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,7 @@ impl Tool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Re-export API wire types used by the agent turn loop
|
// Re-export API wire types used by the agent turn loop
|
||||||
pub use super::api::types::{FunctionCall, ToolCall, ToolCallDelta};
|
use super::api::types::ToolCall;
|
||||||
|
|
||||||
/// A tool call in flight — metadata for TUI + JoinHandle for
|
/// A tool call in flight — metadata for TUI + JoinHandle for
|
||||||
/// result collection and cancellation.
|
/// result collection and cancellation.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue