Parsing fixes

Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
This commit is contained in:
Kent Overstreet 2026-04-09 16:20:11 -04:00
parent b55230ce3f
commit 0af97774f4
2 changed files with 143 additions and 127 deletions

View file

@ -45,7 +45,7 @@ pub(crate) struct SamplingParams {
/// One token from the streaming completions API. /// One token from the streaming completions API.
pub enum StreamToken { pub enum StreamToken {
Token { text: String, id: u32 }, Token(u32),
Done { usage: Option<Usage> }, Done { usage: Option<Usage> },
Error(String), Error(String),
} }
@ -159,20 +159,19 @@ async fn stream_completions(
}; };
for choice in choices { for choice in choices {
let text = choice["text"].as_str().unwrap_or(""); if let Some(ids) = choice["token_ids"].as_array() {
let token_ids = choice["token_ids"].as_array(); for id_val in ids {
if let Some(ids) = token_ids {
for (i, id_val) in ids.iter().enumerate() {
if let Some(id) = id_val.as_u64() { if let Some(id) = id_val.as_u64() {
let _ = tx.send(StreamToken::Token { let _ = tx.send(StreamToken::Token(id as u32));
text: if i == 0 { text.to_string() } else { String::new() }, }
id: id as u32, }
}); } else if let Some(text) = choice["text"].as_str() {
// Fallback: provider didn't return token_ids, encode locally
if !text.is_empty() {
for id in super::tokenizer::encode(text) {
let _ = tx.send(StreamToken::Token(id));
} }
} }
} else if !text.is_empty() {
let _ = tx.send(StreamToken::Token { text: text.to_string(), id: 0 });
} }
} }
} }

View file

@ -444,47 +444,57 @@ fn format_tool_call_xml(name: &str, args_json: &str) -> String {
xml xml
} }
fn normalize_xml_tags(text: &str) -> String { /// Search for a sequence of literal parts separated by optional ASCII whitespace.
let mut result = String::with_capacity(text.len()); /// Returns (start, end) byte positions of the overall match.
let mut chars = text.chars().peekable(); ///
let mut in_tag = false; /// Handles the case where streaming tokenization inserts whitespace inside
/// XML tag structure, e.g. `< function = bash >` instead of `<function=bash>`.
while let Some(ch) = chars.next() { fn find_ws_seq(s: &str, parts: &[&str]) -> Option<(usize, usize)> {
if ch == '<' { let bytes = s.as_bytes();
in_tag = true; let mut search_from = 0;
result.push(ch); 'outer: loop {
} else if ch == '>' { let start = s[search_from..].find(parts[0])? + search_from;
result.push(ch); let mut pos = start + parts[0].len();
in_tag = false; for &part in &parts[1..] {
} else if in_tag && ch.is_whitespace() { while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
// Skip whitespace inside tags (between < and >) pos += 1;
continue;
} else {
result.push(ch);
} }
if !s[pos..].starts_with(part) {
search_from = start + 1;
continue 'outer;
}
pos += part.len();
}
return Some((start, pos));
} }
result
} }
/// Parse a Qwen-style XML tag: `<tag=name>body</tag>`.
/// Tolerates whitespace inside tag delimiters (streaming artifact).
/// Body content is returned verbatim except for a single leading/trailing
/// newline (XML formatting convention).
fn parse_qwen_tag<'a>(s: &'a str, tag: &str) -> Option<(&'a str, &'a str, &'a str)> { fn parse_qwen_tag<'a>(s: &'a str, tag: &str) -> Option<(&'a str, &'a str, &'a str)> {
let open = format!("<{}=", tag); // Open tag: tolerate whitespace from streaming tokenization
let (_, after_eq) = find_ws_seq(s, &["<", tag, "="])?;
let gt_offset = s[after_eq..].find('>')?;
let name = s[after_eq..after_eq + gt_offset].trim();
let body_start = after_eq + gt_offset + 1;
// Close tag: exact match — model doesn't insert whitespace in close tags
let close = format!("</{}>", tag); let close = format!("</{}>", tag);
let close_offset = s[body_start..].find(&close)?;
let body = &s[body_start..body_start + close_offset];
// Strip the single leading/trailing newline from XML formatting,
// but preserve all other whitespace (indentation matters for code).
let body = body.strip_prefix('\n').unwrap_or(body);
let body = body.strip_suffix('\n').unwrap_or(body);
let rest = &s[body_start + close_offset + close.len()..];
let start = s.find(&open)? + open.len(); Some((name, body, rest))
let name_end = start + s[start..].find('>')?;
let body_start = name_end + 1;
let body_end = body_start + s[body_start..].find(&close)?;
Some((
s[start..name_end].trim(),
s[body_start..body_end].trim(),
&s[body_end + close.len()..],
))
} }
fn parse_tool_call_body(body: &str) -> Option<(String, String)> { fn parse_tool_call_body(body: &str) -> Option<(String, String)> {
let normalized = normalize_xml_tags(body); let body = body.trim();
let body = normalized.trim();
parse_xml_tool_call(body) parse_xml_tool_call(body)
.or_else(|| parse_json_tool_call(body)) .or_else(|| parse_json_tool_call(body))
} }
@ -509,6 +519,38 @@ fn parse_json_tool_call(body: &str) -> Option<(String, String)> {
Some((name.to_string(), serde_json::to_string(arguments).unwrap_or_default())) Some((name.to_string(), serde_json::to_string(arguments).unwrap_or_default()))
} }
/// Look for `close_tag` in `buf` and move the text preceding it into `accum`.
///
/// * Tag present: the text before the tag is appended to `accum`, `buf` is
///   reduced to whatever follows the tag, and the whole accumulated content
///   is returned (leaving `accum` empty).
/// * Tag absent: the prefix of `buf` released by `drain_safe` is appended to
///   `accum`, the retained tail stays in `buf` for the next call, and `None`
///   is returned.
fn scan_close_tag(buf: &mut String, close_tag: &str, accum: &mut String) -> Option<String> {
    let Some(tag_at) = buf.find(close_tag) else {
        // No complete close tag yet: hand over what `drain_safe` releases.
        // Appending an empty string is a no-op, so no emptiness check is needed.
        let settled = drain_safe(buf, close_tag.len());
        accum.push_str(&settled);
        return None;
    };

    let (before, from_tag) = buf.split_at(tag_at);
    accum.push_str(before);
    let remainder = from_tag[close_tag.len()..].to_string();
    *buf = remainder;
    Some(std::mem::take(accum))
}
/// Remove everything from `buf` except its last `tag_len` bytes, which might
/// be the beginning of a tag that has not fully arrived yet.
///
/// The split point is rounded down to the nearest UTF-8 character boundary,
/// so a few more than `tag_len` bytes may be retained when the cut would
/// otherwise land inside a multi-byte character.
///
/// Returns the removed prefix. When `buf` holds at most `tag_len` bytes, or
/// the split point rounds down to zero, `buf` is left untouched and an empty
/// string is returned.
fn drain_safe(buf: &mut String, tag_len: usize) -> String {
    let split = buf.floor_char_boundary(buf.len().saturating_sub(tag_len));
    if split == 0 {
        return String::new();
    }
    // `split_off` moves only the retained tail into a fresh allocation; the
    // original allocation, now holding just the prefix, is handed back to
    // the caller. This avoids copying both halves into new strings.
    let tail = buf.split_off(split);
    std::mem::replace(buf, tail)
}
impl ResponseParser { impl ResponseParser {
pub fn new(branch_idx: usize) -> Self { pub fn new(branch_idx: usize) -> Self {
Self { Self {
@ -544,10 +586,11 @@ impl ResponseParser {
let mut full_text = String::new(); let mut full_text = String::new();
while let Some(event) = stream.recv().await { while let Some(event) = stream.recv().await {
match event { match event {
super::api::StreamToken::Token { text, id } => { super::api::StreamToken::Token(id) => {
let text = super::tokenizer::decode(&[id]);
full_text.push_str(&text); full_text.push_str(&text);
let mut ctx = agent.context.lock().await; let mut ctx = agent.context.lock().await;
let calls = parser.feed_token(&text, id, &mut ctx); let calls = parser.feed_token(&text, &mut ctx);
if !calls.is_empty() { if !calls.is_empty() {
if let Some(ref mut f) = log_file { if let Some(ref mut f) = log_file {
use std::io::Write; use std::io::Write;
@ -596,42 +639,33 @@ impl ResponseParser {
(rx, handle) (rx, handle)
} }
pub fn feed_token(&mut self, text: &str, _token_id: u32, ctx: &mut ContextState) -> Vec<PendingToolCall> { pub fn feed_token(&mut self, text: &str, ctx: &mut ContextState) -> Vec<PendingToolCall> {
const THINK_OPEN: &str = "<think>";
const THINK_CLOSE: &str = "</think>";
const TOOL_CALL_OPEN: &str = "<tool_call>";
const TOOL_CALL_CLOSE: &str = "</tool_call>";
const OPEN_TAGS: &[&str] = &[THINK_OPEN, TOOL_CALL_OPEN];
let mut pending = Vec::new(); let mut pending = Vec::new();
self.buf.push_str(text); self.buf.push_str(text);
loop { loop {
if self.in_think { if self.in_think {
match self.buf.find("</think>") { if let Some(content) = scan_close_tag(&mut self.buf, THINK_CLOSE, &mut self.think_buf) {
Some(end) => {
self.think_buf.push_str(&self.buf[..end]);
self.buf = self.buf[end + 8..].to_string();
self.in_think = false; self.in_think = false;
let text = std::mem::take(&mut self.think_buf).trim().to_string(); let text = content.trim().to_string();
if !text.is_empty() { if !text.is_empty() {
self.push_child(ctx, AstNode::thinking(text)); self.push_child(ctx, AstNode::thinking(text));
} }
continue; continue;
} }
None => {
let safe = self.buf.len().saturating_sub(8);
if safe > 0 {
let safe = self.buf.floor_char_boundary(safe);
self.think_buf.push_str(&self.buf[..safe]);
self.buf = self.buf[safe..].to_string();
}
break; break;
} }
}
}
if self.in_tool_call { if self.in_tool_call {
match self.buf.find("</tool_call>") { if let Some(content) = scan_close_tag(&mut self.buf, TOOL_CALL_CLOSE, &mut self.tool_call_buf) {
Some(end) => {
self.tool_call_buf.push_str(&self.buf[..end]);
self.buf = self.buf[end + 12..].to_string();
self.in_tool_call = false; self.in_tool_call = false;
if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) { if let Some((name, args)) = parse_tool_call_body(&content) {
self.flush_content(ctx); self.flush_content(ctx);
self.push_child(ctx, AstNode::tool_call(&name, &args)); self.push_child(ctx, AstNode::tool_call(&name, &args));
self.call_counter += 1; self.call_counter += 1;
@ -641,52 +675,36 @@ impl ResponseParser {
id: format!("call_{}", self.call_counter), id: format!("call_{}", self.call_counter),
}); });
} }
self.tool_call_buf.clear();
continue; continue;
} }
None => {
let safe = self.buf.len().saturating_sub(12);
if safe > 0 {
let safe = self.buf.floor_char_boundary(safe);
self.tool_call_buf.push_str(&self.buf[..safe]);
self.buf = self.buf[safe..].to_string();
}
break; break;
} }
}
}
let think_pos = self.buf.find("<think>"); // Not inside a tag — find the earliest opening tag
let tool_pos = self.buf.find("<tool_call>"); let next = OPEN_TAGS.iter()
let next_tag = match (think_pos, tool_pos) { .filter_map(|tag| self.buf.find(tag).map(|pos| (pos, *tag)))
(Some(a), Some(b)) => Some(a.min(b)), .min_by_key(|(pos, _)| *pos);
(Some(a), None) => Some(a),
(None, Some(b)) => Some(b),
(None, None) => None,
};
match next_tag { match next {
Some(pos) => { Some((pos, tag)) => {
if pos > 0 { if pos > 0 {
self.content_parts.push(self.buf[..pos].to_string()); self.content_parts.push(self.buf[..pos].to_string());
} }
if self.buf[pos..].starts_with("<think>") { self.buf = self.buf[pos + tag.len()..].to_string();
self.buf = self.buf[pos + 7..].to_string();
self.flush_content(ctx); self.flush_content(ctx);
self.in_think = true; match tag {
} else { THINK_OPEN => self.in_think = true,
self.buf = self.buf[pos + 11..].to_string(); TOOL_CALL_OPEN => self.in_tool_call = true,
self.flush_content(ctx); _ => unreachable!(),
self.in_tool_call = true;
} }
continue; continue;
} }
None => { None => {
let safe = self.buf.len().saturating_sub(11); // Keep a tail that might be a partial opening tag
if safe > 0 { let max_tag = OPEN_TAGS.iter().map(|t| t.len()).max().unwrap();
let safe = self.buf.floor_char_boundary(safe); let drained = drain_safe(&mut self.buf, max_tag);
self.content_parts.push(self.buf[..safe].to_string()); if !drained.is_empty() {
self.buf = self.buf[safe..].to_string(); self.content_parts.push(drained);
} }
break; break;
} }
@ -1008,7 +1026,9 @@ mod tests {
#[test] #[test]
fn test_tool_call_xml_parse_streamed_whitespace() { fn test_tool_call_xml_parse_streamed_whitespace() {
let body = "<\nfunction\n=\nbash\n>\n<\nparameter\n=\ncommand\n>pwd</\nparameter\n>\n</\nfunction\n>"; // Streaming tokenization can insert whitespace in opening tags,
// but close tags are always emitted verbatim.
let body = "<\nfunction\n=\nbash\n>\n<\nparameter\n=\ncommand\n>pwd</parameter>\n</function>";
let (name, args) = parse_tool_call_body(body).unwrap(); let (name, args) = parse_tool_call_body(body).unwrap();
assert_eq!(name, "bash"); assert_eq!(name, "bash");
let args: serde_json::Value = serde_json::from_str(&args).unwrap(); let args: serde_json::Value = serde_json::from_str(&args).unwrap();
@ -1025,15 +1045,12 @@ mod tests {
} }
#[test] #[test]
fn test_normalize_preserves_content() { fn test_tool_call_preserves_code_with_angle_brackets() {
let text = "<function=bash>\n<parameter=command>echo hello world</parameter>\n</function>"; let body = "<function=edit>\n<parameter=code>if x < y {\n std::mem::swap(&mut a, &mut b);\n}</parameter>\n</function>";
let normalized = normalize_xml_tags(text); let (name, args) = parse_tool_call_body(body).unwrap();
assert_eq!(normalized, text); assert_eq!(name, "edit");
} let args: serde_json::Value = serde_json::from_str(&args).unwrap();
assert_eq!(args["code"], "if x < y {\n std::mem::swap(&mut a, &mut b);\n}");
#[test]
fn test_normalize_strips_tag_internal_whitespace() {
assert_eq!(normalize_xml_tags("<\nfunction\n=\nbash\n>"), "<function=bash>");
} }
// -- ResponseParser tests ------------------------------------------------- // -- ResponseParser tests -------------------------------------------------
@ -1047,7 +1064,7 @@ mod tests {
let mut calls = Vec::new(); let mut calls = Vec::new();
for chunk in chunks { for chunk in chunks {
// Feed each chunk as a single token (id=0 for tests) // Feed each chunk as a single token (id=0 for tests)
calls.extend(p.feed_token(chunk, 0, &mut ctx)); calls.extend(p.feed_token(chunk, &mut ctx));
} }
p.finish(&mut ctx); p.finish(&mut ctx);
(ctx, calls) (ctx, calls)
@ -1109,7 +1126,7 @@ mod tests {
ctx.push_no_log(Section::Conversation, AstNode::branch(Role::Assistant, vec![])); ctx.push_no_log(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
let mut p = ResponseParser::new(0); let mut p = ResponseParser::new(0);
for ch in text.chars() { for ch in text.chars() {
p.feed_token(&ch.to_string(), 0, &mut ctx); p.feed_token(&ch.to_string(), &mut ctx);
} }
p.finish(&mut ctx); p.finish(&mut ctx);
let b = bodies(assistant_children(&ctx)); let b = bodies(assistant_children(&ctx));
@ -1126,7 +1143,7 @@ mod tests {
let mut p = ResponseParser::new(0); let mut p = ResponseParser::new(0);
let mut tool_calls = 0; let mut tool_calls = 0;
for ch in text.chars() { for ch in text.chars() {
tool_calls += p.feed_token(&ch.to_string(), 0, &mut ctx).len(); tool_calls += p.feed_token(&ch.to_string(), &mut ctx).len();
} }
p.finish(&mut ctx); p.finish(&mut ctx);
assert_eq!(tool_calls, 1); assert_eq!(tool_calls, 1);