Parser consumes stream directly, yields tool calls via channel
ResponseParser::run() spawns a task that reads StreamTokens, parses them into the AST (locking the context per token), and sends PendingToolCalls through a channel. It returns (tool_rx, JoinHandle<Result>) — the turn loop dispatches tool calls and awaits the handle for error checking.

Token IDs from vLLM are accumulated alongside text and stored directly on AST leaves — no local re-encoding on the response path.

The turn loop no longer matches on individual stream events. It just reads tool calls and dispatches them.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
0b9813431a
commit
2c401e24d6
3 changed files with 119 additions and 85 deletions
106
src/agent/mod.rs
106
src/agent/mod.rs
|
|
@ -339,77 +339,55 @@ impl Agent {
|
|||
AstNode::branch(Role::Assistant, vec![]));
|
||||
idx
|
||||
};
|
||||
let mut parser = ResponseParser::new(branch_idx);
|
||||
let mut pending_calls: Vec<PendingToolCall> = Vec::new();
|
||||
let mut had_content = false;
|
||||
let mut stream_error: Option<String> = None;
|
||||
|
||||
// Stream loop — no lock held across I/O
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
api::StreamToken::Token { text, id: _ } => {
|
||||
had_content = true;
|
||||
let mut ctx = agent.context.lock().await;
|
||||
let calls = parser.feed(&text, &mut ctx);
|
||||
drop(ctx);
|
||||
for call in calls {
|
||||
let call_clone = call.clone();
|
||||
let agent_handle = agent.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&call_clone.arguments).unwrap_or_default();
|
||||
let output = tools::dispatch_with_agent(
|
||||
&call_clone.name, &args, Some(agent_handle),
|
||||
).await;
|
||||
(call_clone, output)
|
||||
});
|
||||
active_tools.lock().unwrap().push(tools::ActiveToolCall {
|
||||
id: call.id.clone(),
|
||||
name: call.name.clone(),
|
||||
detail: call.arguments.clone(),
|
||||
started: std::time::Instant::now(),
|
||||
background: false,
|
||||
handle,
|
||||
});
|
||||
pending_calls.push(call);
|
||||
}
|
||||
}
|
||||
api::StreamToken::Error(e) => {
|
||||
stream_error = Some(e);
|
||||
break;
|
||||
}
|
||||
api::StreamToken::Done { usage } => {
|
||||
if let Some(u) = usage {
|
||||
agent.state.lock().await.last_prompt_tokens = u.prompt_tokens;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
let parser = ResponseParser::new(branch_idx);
|
||||
let (mut tool_rx, parser_handle) = parser.run(rx, agent.clone());
|
||||
|
||||
let mut pending_calls: Vec<PendingToolCall> = Vec::new();
|
||||
while let Some(call) = tool_rx.recv().await {
|
||||
let call_clone = call.clone();
|
||||
let agent_handle = agent.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&call_clone.arguments).unwrap_or_default();
|
||||
let output = tools::dispatch_with_agent(
|
||||
&call_clone.name, &args, Some(agent_handle),
|
||||
).await;
|
||||
(call_clone, output)
|
||||
});
|
||||
active_tools.lock().unwrap().push(tools::ActiveToolCall {
|
||||
id: call.id.clone(),
|
||||
name: call.name.clone(),
|
||||
detail: call.arguments.clone(),
|
||||
started: std::time::Instant::now(),
|
||||
background: false,
|
||||
handle,
|
||||
});
|
||||
pending_calls.push(call);
|
||||
}
|
||||
|
||||
// Flush parser remainder
|
||||
parser.finish(&mut *agent.context.lock().await);
|
||||
|
||||
// Handle errors
|
||||
if let Some(e) = stream_error {
|
||||
let err = anyhow::anyhow!("{}", e);
|
||||
if context::is_context_overflow(&err) && overflow_retries < 2 {
|
||||
overflow_retries += 1;
|
||||
agent.state.lock().await.notify(format!("context overflow — retrying ({}/2)", overflow_retries));
|
||||
agent.compact().await;
|
||||
continue;
|
||||
// Check for stream/parse errors
|
||||
match parser_handle.await {
|
||||
Ok(Err(e)) => {
|
||||
if context::is_context_overflow(&e) && overflow_retries < 2 {
|
||||
overflow_retries += 1;
|
||||
agent.state.lock().await.notify(
|
||||
format!("context overflow — retrying ({}/2)", overflow_retries));
|
||||
agent.compact().await;
|
||||
continue;
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
if context::is_stream_error(&err) && empty_retries < 2 {
|
||||
empty_retries += 1;
|
||||
agent.state.lock().await.notify(format!("stream error — retrying ({}/2)", empty_retries));
|
||||
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||
continue;
|
||||
}
|
||||
return Err(err);
|
||||
Err(e) => return Err(anyhow::anyhow!("parser task panicked: {}", e)),
|
||||
Ok(Ok(())) => {}
|
||||
}
|
||||
|
||||
// Empty response — nudge and retry
|
||||
if !had_content && pending_calls.is_empty() {
|
||||
let has_content = {
|
||||
let ctx = agent.context.lock().await;
|
||||
!ctx.conversation()[branch_idx].children().is_empty()
|
||||
};
|
||||
if !has_content && pending_calls.is_empty() {
|
||||
if empty_retries < 2 {
|
||||
empty_retries += 1;
|
||||
agent.push_node(AstNode::user_msg(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue