diff --git a/src/agent/api/mod.rs b/src/agent/api/mod.rs
index a3c73a0..7c06fa7 100644
--- a/src/agent/api/mod.rs
+++ b/src/agent/api/mod.rs
@@ -45,7 +45,7 @@ pub(crate) struct SamplingParams {
 
 /// One token from the streaming completions API.
 pub enum StreamToken {
-    Token { text: String, id: u32 },
+    Token(u32),
     Done { usage: Option<Usage> },
     Error(String),
 }
@@ -159,20 +159,19 @@ async fn stream_completions(
         };
 
         for choice in choices {
-            let text = choice["text"].as_str().unwrap_or("");
-            let token_ids = choice["token_ids"].as_array();
-
-            if let Some(ids) = token_ids {
-                for (i, id_val) in ids.iter().enumerate() {
+            if let Some(ids) = choice["token_ids"].as_array() {
+                for id_val in ids {
                     if let Some(id) = id_val.as_u64() {
-                        let _ = tx.send(StreamToken::Token {
-                            text: if i == 0 { text.to_string() } else { String::new() },
-                            id: id as u32,
-                        });
+                        let _ = tx.send(StreamToken::Token(id as u32));
+                    }
+                }
+            } else if let Some(text) = choice["text"].as_str() {
+                // Fallback: provider didn't return token_ids, encode locally
+                if !text.is_empty() {
+                    for id in super::tokenizer::encode(text) {
+                        let _ = tx.send(StreamToken::Token(id));
                     }
                 }
-            } else if !text.is_empty() {
-                let _ = tx.send(StreamToken::Token { text: text.to_string(), id: 0 });
             }
         }
     }
diff --git a/src/agent/context.rs b/src/agent/context.rs
index 5064405..93ef607 100644
--- a/src/agent/context.rs
+++ b/src/agent/context.rs
@@ -444,47 +444,57 @@ fn format_tool_call_xml(name: &str, args_json: &str) -> String {
     xml
 }
 
-fn normalize_xml_tags(text: &str) -> String {
-    let mut result = String::with_capacity(text.len());
-    let mut chars = text.chars().peekable();
-    let mut in_tag = false;
-
-    while let Some(ch) = chars.next() {
-        if ch == '<' {
-            in_tag = true;
-            result.push(ch);
-        } else if ch == '>' {
-            result.push(ch);
-            in_tag = false;
-        } else if in_tag && ch.is_whitespace() {
-            // Skip whitespace inside tags (between < and >)
-            continue;
-        } else {
-            result.push(ch);
+/// Search for a sequence of literal parts
+/// separated by optional ASCII whitespace.
+/// Returns (start, end) byte positions of the overall match.
+///
+/// Handles the case where streaming tokenization inserts whitespace inside
+/// XML tag structure, e.g. `< function = bash >` instead of `<function=bash>`.
+fn find_ws_seq(s: &str, parts: &[&str]) -> Option<(usize, usize)> {
+    let bytes = s.as_bytes();
+    let mut search_from = 0;
+    'outer: loop {
+        let start = s[search_from..].find(parts[0])? + search_from;
+        let mut pos = start + parts[0].len();
+        for &part in &parts[1..] {
+            while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
+                pos += 1;
+            }
+            if !s[pos..].starts_with(part) {
+                search_from = start + 1;
+                continue 'outer;
+            }
+            pos += part.len();
         }
+        return Some((start, pos));
     }
-    result
 }
 
+/// Parse a Qwen-style XML tag: `<tag=name>body</tag>`.
+/// Tolerates whitespace inside tag delimiters (streaming artifact).
+/// Body content is returned verbatim except for a single leading/trailing
+/// newline (XML formatting convention).
 fn parse_qwen_tag<'a>(s: &'a str, tag: &str) -> Option<(&'a str, &'a str, &'a str)> {
-    let open = format!("<{}=", tag);
+    // Open tag: tolerate whitespace from streaming tokenization
+    let (_, after_eq) = find_ws_seq(s, &["<", tag, "="])?;
+    let gt_offset = s[after_eq..].find('>')?;
+    let name = s[after_eq..after_eq + gt_offset].trim();
+    let body_start = after_eq + gt_offset + 1;
+
+    // Close tag: exact match — model doesn't insert whitespace in close tags
     let close = format!("</{}>", tag);
+    let close_offset = s[body_start..].find(&close)?;
+    let body = &s[body_start..body_start + close_offset];
+    // Strip the single leading/trailing newline from XML formatting,
+    // but preserve all other whitespace (indentation matters for code).
+    let body = body.strip_prefix('\n').unwrap_or(body);
+    let body = body.strip_suffix('\n').unwrap_or(body);
+    let rest = &s[body_start + close_offset + close.len()..];
-    let start = s.find(&open)?
-        + open.len();
-    let name_end = start + s[start..].find('>')?;
-    let body_start = name_end + 1;
-    let body_end = body_start + s[body_start..].find(&close)?;
-
-    Some((
-        s[start..name_end].trim(),
-        s[body_start..body_end].trim(),
-        &s[body_end + close.len()..],
-    ))
+    Some((name, body, rest))
 }
 
 fn parse_tool_call_body(body: &str) -> Option<(String, String)> {
-    let normalized = normalize_xml_tags(body);
-    let body = normalized.trim();
+    let body = body.trim();
     parse_xml_tool_call(body)
         .or_else(|| parse_json_tool_call(body))
 }
@@ -509,6 +519,38 @@ fn parse_json_tool_call(body: &str) -> Option<(String, String)> {
     Some((name.to_string(), serde_json::to_string(arguments).unwrap_or_default()))
 }
 
+/// Search `buf` for `close_tag`. If found, append everything before it to
+/// `accum`, advance `buf` past the tag, and return the accumulated content.
+/// If not found, drain the safe prefix (preserving any partial tag match at
+/// the end of buf) into `accum`.
+fn scan_close_tag(buf: &mut String, close_tag: &str, accum: &mut String) -> Option<String> {
+    if let Some(pos) = buf.find(close_tag) {
+        accum.push_str(&buf[..pos]);
+        *buf = buf[pos + close_tag.len()..].to_string();
+        Some(std::mem::take(accum))
+    } else {
+        let drained = drain_safe(buf, close_tag.len());
+        if !drained.is_empty() {
+            accum.push_str(&drained);
+        }
+        None
+    }
+}
+
+/// Remove everything from `buf` except the last `tag_len` bytes, which might
+/// be a partial tag. Returns the removed prefix.
+fn drain_safe(buf: &mut String, tag_len: usize) -> String {
+    let safe = buf.len().saturating_sub(tag_len);
+    if safe > 0 {
+        let safe = buf.floor_char_boundary(safe);
+        let drained = buf[..safe].to_string();
+        *buf = buf[safe..].to_string();
+        drained
+    } else {
+        String::new()
+    }
+}
+
 impl ResponseParser {
     pub fn new(branch_idx: usize) -> Self {
         Self {
@@ -544,10 +586,11 @@ impl ResponseParser {
         let mut full_text = String::new();
         while let Some(event) = stream.recv().await {
             match event {
-                super::api::StreamToken::Token { text, id } => {
+                super::api::StreamToken::Token(id) => {
+                    let text = super::tokenizer::decode(&[id]);
                     full_text.push_str(&text);
                     let mut ctx = agent.context.lock().await;
-                    let calls = parser.feed_token(&text, id, &mut ctx);
+                    let calls = parser.feed_token(&text, &mut ctx);
                     if !calls.is_empty() {
                         if let Some(ref mut f) = log_file {
                             use std::io::Write;
@@ -596,97 +639,72 @@ impl ResponseParser {
         (rx, handle)
     }
 
-    pub fn feed_token(&mut self, text: &str, _token_id: u32, ctx: &mut ContextState) -> Vec<PendingToolCall> {
+    pub fn feed_token(&mut self, text: &str, ctx: &mut ContextState) -> Vec<PendingToolCall> {
+        const THINK_OPEN: &str = "<think>";
+        const THINK_CLOSE: &str = "</think>";
+        const TOOL_CALL_OPEN: &str = "<tool_call>";
+        const TOOL_CALL_CLOSE: &str = "</tool_call>";
+        const OPEN_TAGS: &[&str] = &[THINK_OPEN, TOOL_CALL_OPEN];
+
         let mut pending = Vec::new();
         self.buf.push_str(text);
 
         loop {
             if self.in_think {
-                match self.buf.find("</think>") {
-                    Some(end) => {
-                        self.think_buf.push_str(&self.buf[..end]);
-                        self.buf = self.buf[end + 8..].to_string();
-                        self.in_think = false;
-                        let text = std::mem::take(&mut self.think_buf).trim().to_string();
-                        if !text.is_empty() {
-                            self.push_child(ctx, AstNode::thinking(text));
-                        }
-                        continue;
-                    }
-                    None => {
-                        let safe = self.buf.len().saturating_sub(8);
-                        if safe > 0 {
-                            let safe = self.buf.floor_char_boundary(safe);
-                            self.think_buf.push_str(&self.buf[..safe]);
-                            self.buf = self.buf[safe..].to_string();
-                        }
-                        break;
+                if let Some(content) = scan_close_tag(&mut self.buf,
+                        THINK_CLOSE, &mut self.think_buf) {
+                    self.in_think = false;
+                    let text = content.trim().to_string();
+                    if !text.is_empty() {
+                        self.push_child(ctx, AstNode::thinking(text));
                     }
+                    continue;
                 }
+                break;
             }
 
             if self.in_tool_call {
-                match self.buf.find("</tool_call>") {
-                    Some(end) => {
-                        self.tool_call_buf.push_str(&self.buf[..end]);
-                        self.buf = self.buf[end + 12..].to_string();
-                        self.in_tool_call = false;
-                        if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) {
-                            self.flush_content(ctx);
-                            self.push_child(ctx, AstNode::tool_call(&name, &args));
-                            self.call_counter += 1;
-                            pending.push(PendingToolCall {
-                                name,
-                                arguments: args,
-                                id: format!("call_{}", self.call_counter),
-                            });
-                        }
-                        self.tool_call_buf.clear();
-                        continue;
-                    }
-                    None => {
-                        let safe = self.buf.len().saturating_sub(12);
-                        if safe > 0 {
-                            let safe = self.buf.floor_char_boundary(safe);
-                            self.tool_call_buf.push_str(&self.buf[..safe]);
-                            self.buf = self.buf[safe..].to_string();
-                        }
-                        break;
+                if let Some(content) = scan_close_tag(&mut self.buf, TOOL_CALL_CLOSE, &mut self.tool_call_buf) {
+                    self.in_tool_call = false;
+                    if let Some((name, args)) = parse_tool_call_body(&content) {
+                        self.flush_content(ctx);
+                        self.push_child(ctx, AstNode::tool_call(&name, &args));
+                        self.call_counter += 1;
+                        pending.push(PendingToolCall {
+                            name,
+                            arguments: args,
+                            id: format!("call_{}", self.call_counter),
+                        });
                     }
+                    continue;
                 }
+                break;
             }
 
-            let think_pos = self.buf.find("<think>");
-            let tool_pos = self.buf.find("<tool_call>");
-            let next_tag = match (think_pos, tool_pos) {
-                (Some(a), Some(b)) => Some(a.min(b)),
-                (Some(a), None) => Some(a),
-                (None, Some(b)) => Some(b),
-                (None, None) => None,
-            };
+            // Not inside a tag — find the earliest opening tag
+            let next = OPEN_TAGS.iter()
+                .filter_map(|tag| self.buf.find(tag).map(|pos| (pos, *tag)))
+                .min_by_key(|(pos, _)| *pos);
 
-            match next_tag {
-                Some(pos) => {
+            match next {
+                Some((pos, tag)) => {
                     if pos > 0 {
                         self.content_parts.push(self.buf[..pos].to_string());
                     }
-                    if
-                        self.buf[pos..].starts_with("<think>") {
-                        self.buf = self.buf[pos + 7..].to_string();
-                        self.flush_content(ctx);
-                        self.in_think = true;
-                    } else {
-                        self.buf = self.buf[pos + 11..].to_string();
-                        self.flush_content(ctx);
-                        self.in_tool_call = true;
+                    self.buf = self.buf[pos + tag.len()..].to_string();
+                    self.flush_content(ctx);
+                    match tag {
+                        THINK_OPEN => self.in_think = true,
+                        TOOL_CALL_OPEN => self.in_tool_call = true,
+                        _ => unreachable!(),
                     }
                     continue;
                 }
                 None => {
-                    let safe = self.buf.len().saturating_sub(11);
-                    if safe > 0 {
-                        let safe = self.buf.floor_char_boundary(safe);
-                        self.content_parts.push(self.buf[..safe].to_string());
-                        self.buf = self.buf[safe..].to_string();
+                    // Keep a tail that might be a partial opening tag
+                    let max_tag = OPEN_TAGS.iter().map(|t| t.len()).max().unwrap();
+                    let drained = drain_safe(&mut self.buf, max_tag);
+                    if !drained.is_empty() {
+                        self.content_parts.push(drained);
                     }
                     break;
                 }
@@ -1008,7 +1026,9 @@ mod tests {
 
     #[test]
     fn test_tool_call_xml_parse_streamed_whitespace() {
-        let body = "<\nfunction\n=\nbash\n>\n<\nparameter\n=\ncommand\n>pwd\n</parameter\n></function\n>";
+        // Streaming tokenization can insert whitespace in opening tags,
+        // but close tags are always emitted verbatim.
+        let body = "<\nfunction\n=\nbash\n>\n<\nparameter\n=\ncommand\n>pwd\n</parameter></function>";
         let (name, args) = parse_tool_call_body(body).unwrap();
         assert_eq!(name, "bash");
         let args: serde_json::Value = serde_json::from_str(&args).unwrap();
@@ -1025,15 +1045,12 @@ mod tests {
     }
 
     #[test]
-    fn test_normalize_preserves_content() {
-        let text = "<function=bash>\necho hello world\n</function>";
-        let normalized = normalize_xml_tags(text);
-        assert_eq!(normalized, text);
-    }
-
-    #[test]
-    fn test_normalize_strips_tag_internal_whitespace() {
-        assert_eq!(normalize_xml_tags("<\nfunction\n=\nbash\n>"), "<function=bash>");
+    fn test_tool_call_preserves_code_with_angle_brackets() {
+        let body = "<function=edit><parameter=code>\nif x < y {\n    std::mem::swap(&mut a, &mut b);\n}\n</parameter></function>";
+        let (name, args) = parse_tool_call_body(body).unwrap();
+        assert_eq!(name, "edit");
+        let args: serde_json::Value = serde_json::from_str(&args).unwrap();
+        assert_eq!(args["code"], "if x < y {\n    std::mem::swap(&mut a, &mut b);\n}");
     }
 
     // -- ResponseParser tests -------------------------------------------------
@@ -1047,7 +1064,7 @@ mod tests {
         let mut calls = Vec::new();
         for chunk in chunks {
             // Feed each chunk as a single token (id=0 for tests)
-            calls.extend(p.feed_token(chunk, 0, &mut ctx));
+            calls.extend(p.feed_token(chunk, &mut ctx));
         }
         p.finish(&mut ctx);
         (ctx, calls)
@@ -1109,7 +1126,7 @@ mod tests {
         ctx.push_no_log(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
         let mut p = ResponseParser::new(0);
         for ch in text.chars() {
-            p.feed_token(&ch.to_string(), 0, &mut ctx);
+            p.feed_token(&ch.to_string(), &mut ctx);
         }
         p.finish(&mut ctx);
         let b = bodies(assistant_children(&ctx));
@@ -1126,7 +1143,7 @@ mod tests {
         let mut p = ResponseParser::new(0);
         let mut tool_calls = 0;
         for ch in text.chars() {
-            tool_calls += p.feed_token(&ch.to_string(), 0, &mut ctx).len();
+            tool_calls += p.feed_token(&ch.to_string(), &mut ctx).len();
         }
         p.finish(&mut ctx);
         assert_eq!(tool_calls, 1);