diff --git a/src/agent/api/mod.rs b/src/agent/api/mod.rs index dc9f0fd..c0e0f6e 100644 --- a/src/agent/api/mod.rs +++ b/src/agent/api/mod.rs @@ -8,14 +8,13 @@ pub mod http; -use anyhow::Result; use std::time::{Duration, Instant}; - -use self::http::{HttpClient, HttpResponse}; - +use anyhow::Result; use tokio::sync::mpsc; use serde::Deserialize; +use http::{HttpClient, HttpResponse}; + #[derive(Debug, Clone, Deserialize)] pub struct Usage { pub prompt_tokens: u32, diff --git a/src/agent/context.rs b/src/agent/context.rs index e0d05f9..43d5f2f 100644 --- a/src/agent/context.rs +++ b/src/agent/context.rs @@ -115,11 +115,15 @@ pub struct ResponseParser { branch_idx: usize, call_counter: u32, buf: String, + buf_token_ids: Vec<u32>, content_parts: Vec<String>, + content_token_ids: Vec<u32>, in_think: bool, think_buf: String, + think_token_ids: Vec<u32>, in_tool_call: bool, tool_call_buf: String, + tool_call_token_ids: Vec<u32>, } impl Role { @@ -462,36 +466,80 @@ fn parse_json_tool_call(body: &str) -> Option<(String, String)> { } impl ResponseParser { - /// Create a parser that pushes children into the assistant branch - /// at `branch_idx` in the conversation section. pub fn new(branch_idx: usize) -> Self { Self { branch_idx, call_counter: 0, buf: String::new(), + buf_token_ids: Vec::new(), content_parts: Vec::new(), + content_token_ids: Vec::new(), in_think: false, think_buf: String::new(), + think_token_ids: Vec::new(), in_tool_call: false, tool_call_buf: String::new(), + tool_call_token_ids: Vec::new(), } } - /// Feed a text chunk. Completed children are pushed directly into - /// the AST. Returns any tool calls that need dispatching. - pub fn feed(&mut self, text: &str, ctx: &mut ContextState) -> Vec<PendingToolCall> { + /// Consume a token stream, parse into the AST, yield tool calls. + /// Spawns a background task. Returns a tool call receiver and a + /// join handle that resolves to Ok(()) or the stream error. 
+ pub fn run( + self, + mut stream: tokio::sync::mpsc::UnboundedReceiver<super::api::StreamToken>, + agent: std::sync::Arc<super::Agent>, + ) -> ( + tokio::sync::mpsc::UnboundedReceiver<PendingToolCall>, + tokio::task::JoinHandle<anyhow::Result<()>>, + ) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let handle = tokio::spawn(async move { + let mut parser = self; + while let Some(event) = stream.recv().await { + match event { + super::api::StreamToken::Token { text, id } => { + let mut ctx = agent.context.lock().await; + for call in parser.feed_token(&text, id, &mut ctx) { + let _ = tx.send(call); + } + } + super::api::StreamToken::Done { usage } => { + if let Some(u) = usage { + agent.state.lock().await.last_prompt_tokens = u.prompt_tokens; + } + let mut ctx = agent.context.lock().await; + parser.finish(&mut ctx); + return Ok(()); + } + super::api::StreamToken::Error(e) => { + return Err(anyhow::anyhow!("{}", e)); + } + } + } + Ok(()) + }); + (rx, handle) + } + + pub fn feed_token(&mut self, text: &str, token_id: u32, ctx: &mut ContextState) -> Vec<PendingToolCall> { + let mut pending = Vec::new(); self.buf.push_str(text); + self.buf_token_ids.push(token_id); loop { if self.in_think { match self.buf.find("</think>") { Some(end) => { self.think_buf.push_str(&self.buf[..end]); + // Token IDs: move all buffered IDs to think (approximate split) + self.think_token_ids.extend(self.buf_token_ids.drain(..)); self.buf = self.buf[end + 8..].to_string(); self.in_think = false; - self.push_child(ctx, AstNode::thinking(&self.think_buf)); - self.think_buf.clear(); + let text = std::mem::take(&mut self.think_buf); + let ids = std::mem::take(&mut self.think_token_ids); + self.push_child_with_tokens(ctx, NodeBody::Thinking(text), ids); continue; } None => { @@ -500,6 +548,7 @@ impl ResponseParser { let safe = self.buf.floor_char_boundary(safe); self.think_buf.push_str(&self.buf[..safe]); self.buf = self.buf[safe..].to_string(); + // Keep token IDs in buf (lookahead) } break; } @@ -510,10 +559,12 @@ match self.buf.find("</tool_call>") { Some(end) => { 
self.tool_call_buf.push_str(&self.buf[..end]); + self.tool_call_token_ids.extend(self.buf_token_ids.drain(..)); self.buf = self.buf[end + 12..].to_string(); self.in_tool_call = false; if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) { self.flush_content(ctx); + // Tool calls get re-tokenized from structured data self.push_child(ctx, AstNode::tool_call(&name, &args)); self.call_counter += 1; pending.push(PendingToolCall { @@ -523,6 +574,7 @@ }); } self.tool_call_buf.clear(); + self.tool_call_token_ids.clear(); continue; } None => { @@ -551,6 +603,8 @@ if pos > 0 { self.content_parts.push(self.buf[..pos].to_string()); } + // Move token IDs to content accumulator + self.content_token_ids.extend(self.buf_token_ids.drain(..)); if self.buf[pos..].starts_with("<think>") { self.buf = self.buf[pos + 7..].to_string(); self.flush_content(ctx); @@ -568,6 +622,7 @@ let safe = self.buf.floor_char_boundary(safe); self.content_parts.push(self.buf[..safe].to_string()); self.buf = self.buf[safe..].to_string(); + // Keep token IDs in buf (lookahead) } break; } @@ -581,27 +636,28 @@ ctx.push_child(Section::Conversation, self.branch_idx, child); } + fn push_child_with_tokens(&self, ctx: &mut ContextState, body: NodeBody, token_ids: Vec<u32>) { + let leaf = NodeLeaf { body, token_ids, timestamp: None }; + ctx.push_child(Section::Conversation, self.branch_idx, AstNode::Leaf(leaf)); + } + fn flush_content(&mut self, ctx: &mut ContextState) { if !self.content_parts.is_empty() { let text: String = self.content_parts.drain(..).collect(); if !text.is_empty() { - self.push_child(ctx, AstNode::content(text)); + let token_ids = std::mem::take(&mut self.content_token_ids); + self.push_child_with_tokens(ctx, NodeBody::Content(text), token_ids); } } } - /// Flush remaining buffer into the AST. 
pub fn finish(mut self, ctx: &mut ContextState) { if !self.buf.is_empty() { self.content_parts.push(std::mem::take(&mut self.buf)); + self.content_token_ids.extend(self.buf_token_ids.drain(..)); } self.flush_content(ctx); } - - /// Current display text (content accumulated since last drain). - pub fn display_content(&self) -> String { - self.content_parts.join("") - } } impl ContextState { @@ -838,7 +894,8 @@ mod tests { let mut p = ResponseParser::new(0); let mut calls = Vec::new(); for chunk in chunks { - calls.extend(p.feed(chunk, &mut ctx)); + // Feed each chunk as a single token (id=0 for tests) + calls.extend(p.feed_token(chunk, 0, &mut ctx)); } p.finish(&mut ctx); (ctx, calls) @@ -900,7 +957,7 @@ ctx.push(Section::Conversation, AstNode::branch(Role::Assistant, vec![])); let mut p = ResponseParser::new(0); for ch in text.chars() { - p.feed(&ch.to_string(), &mut ctx); + p.feed_token(&ch.to_string(), 0, &mut ctx); } p.finish(&mut ctx); let b = bodies(assistant_children(&ctx)); @@ -917,7 +974,7 @@ let mut p = ResponseParser::new(0); let mut tool_calls = 0; for ch in text.chars() { - tool_calls += p.feed(&ch.to_string(), &mut ctx).len(); + tool_calls += p.feed_token(&ch.to_string(), 0, &mut ctx).len(); } p.finish(&mut ctx); assert_eq!(tool_calls, 1); diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 5e67dc7..0c0e7f3 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -339,77 +339,55 @@ impl Agent { AstNode::branch(Role::Assistant, vec![])); idx }; - let mut parser = ResponseParser::new(branch_idx); - let mut pending_calls: Vec<PendingToolCall> = Vec::new(); - let mut had_content = false; - let mut stream_error: Option<String> = None; - // Stream loop — no lock held across I/O - while let Some(event) = rx.recv().await { - match event { - api::StreamToken::Token { text, id: _ } => { - had_content = true; - let mut ctx = agent.context.lock().await; - let calls = parser.feed(&text, &mut ctx); - drop(ctx); - for call in calls { - let call_clone = call.clone(); 
- let agent_handle = agent.clone(); - let handle = tokio::spawn(async move { - let args: serde_json::Value = - serde_json::from_str(&call_clone.arguments).unwrap_or_default(); - let output = tools::dispatch_with_agent( - &call_clone.name, &args, Some(agent_handle), - ).await; - (call_clone, output) - }); - active_tools.lock().unwrap().push(tools::ActiveToolCall { - id: call.id.clone(), - name: call.name.clone(), - detail: call.arguments.clone(), - started: std::time::Instant::now(), - background: false, - handle, - }); - pending_calls.push(call); - } - } - api::StreamToken::Error(e) => { - stream_error = Some(e); - break; - } - api::StreamToken::Done { usage } => { - if let Some(u) = usage { - agent.state.lock().await.last_prompt_tokens = u.prompt_tokens; - } - break; - } - } + let parser = ResponseParser::new(branch_idx); + let (mut tool_rx, parser_handle) = parser.run(rx, agent.clone()); + + let mut pending_calls: Vec<PendingToolCall> = Vec::new(); + while let Some(call) = tool_rx.recv().await { + let call_clone = call.clone(); + let agent_handle = agent.clone(); + let handle = tokio::spawn(async move { + let args: serde_json::Value = + serde_json::from_str(&call_clone.arguments).unwrap_or_default(); + let output = tools::dispatch_with_agent( + &call_clone.name, &args, Some(agent_handle), + ).await; + (call_clone, output) + }); + active_tools.lock().unwrap().push(tools::ActiveToolCall { + id: call.id.clone(), + name: call.name.clone(), + detail: call.arguments.clone(), + started: std::time::Instant::now(), + background: false, + handle, + }); + pending_calls.push(call); } - // Flush parser remainder - parser.finish(&mut *agent.context.lock().await); - - // Handle errors - if let Some(e) = stream_error { - let err = anyhow::anyhow!("{}", e); - if context::is_context_overflow(&err) && overflow_retries < 2 { - overflow_retries += 1; - agent.state.lock().await.notify(format!("context overflow — retrying ({}/2)", overflow_retries)); - agent.compact().await; - continue; + // Check for 
stream/parse errors + match parser_handle.await { + Ok(Err(e)) => { + if context::is_context_overflow(&e) && overflow_retries < 2 { + overflow_retries += 1; + agent.state.lock().await.notify( + format!("context overflow — retrying ({}/2)", overflow_retries)); + agent.compact().await; + continue; + } + return Err(e); } - if context::is_stream_error(&err) && empty_retries < 2 { - empty_retries += 1; - agent.state.lock().await.notify(format!("stream error — retrying ({}/2)", empty_retries)); - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - continue; - } - return Err(err); + Err(e) => return Err(anyhow::anyhow!("parser task panicked: {}", e)), + Ok(Ok(())) => {} } // Empty response — nudge and retry - if !had_content && pending_calls.is_empty() { + let has_content = { + let ctx = agent.context.lock().await; + !ctx.conversation()[branch_idx].children().is_empty() + }; + if !has_content && pending_calls.is_empty() { if empty_retries < 2 { empty_retries += 1; agent.push_node(AstNode::user_msg(