Parsing fixes
Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
This commit is contained in:
parent
b55230ce3f
commit
0af97774f4
2 changed files with 143 additions and 127 deletions
|
|
@ -444,47 +444,57 @@ fn format_tool_call_xml(name: &str, args_json: &str) -> String {
|
|||
xml
|
||||
}
|
||||
|
||||
/// Remove whitespace that appears between `<` and `>` so streamed XML tags
/// like `< function = bash >` become `<function=bash>`. Text outside tag
/// delimiters is copied through unchanged.
///
/// NOTE(review): any literal `<` in non-tag content (e.g. code) enters "tag
/// mode" and strips subsequent whitespace — callers must only pass text in
/// which `<`/`>` delimit real tags.
fn normalize_xml_tags(text: &str) -> String {
    let mut result = String::with_capacity(text.len());
    let mut in_tag = false;
    // Plain `for` loop: the original's `chars().peekable()` never peeked.
    for ch in text.chars() {
        if ch == '<' {
            in_tag = true;
            result.push(ch);
        } else if ch == '>' {
            result.push(ch);
            in_tag = false;
        } else if in_tag && ch.is_whitespace() {
            // Skip whitespace inside tags (between < and >)
        } else {
            result.push(ch);
        }
    }
    result
}

/// Search for a sequence of literal parts separated by optional ASCII whitespace.
/// Returns (start, end) byte positions of the overall match.
///
/// Handles the case where streaming tokenization inserts whitespace inside
/// XML tag structure, e.g. `< function = bash >` instead of `<function=bash>`.
fn find_ws_seq(s: &str, parts: &[&str]) -> Option<(usize, usize)> {
    let bytes = s.as_bytes();
    let mut search_from = 0;
    'outer: loop {
        let start = s[search_from..].find(parts[0])? + search_from;
        let mut pos = start + parts[0].len();
        for &part in &parts[1..] {
            // ASCII whitespace between parts is tolerated; byte-wise skipping
            // is boundary-safe because ASCII chars are always one byte.
            while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
                pos += 1;
            }
            if !s[pos..].starts_with(part) {
                // Advance by a full char, not one byte: `start + 1` could land
                // inside a multi-byte char and panic when slicing `s[search_from..]`.
                search_from = start + s[start..].chars().next().map_or(1, char::len_utf8);
                continue 'outer;
            }
            pos += part.len();
        }
        return Some((start, pos));
    }
}
|
||||
|
||||
/// Parse a Qwen-style XML tag: `<tag=name>body</tag>`.
|
||||
/// Tolerates whitespace inside tag delimiters (streaming artifact).
|
||||
/// Body content is returned verbatim except for a single leading/trailing
|
||||
/// newline (XML formatting convention).
|
||||
fn parse_qwen_tag<'a>(s: &'a str, tag: &str) -> Option<(&'a str, &'a str, &'a str)> {
|
||||
let open = format!("<{}=", tag);
|
||||
// Open tag: tolerate whitespace from streaming tokenization
|
||||
let (_, after_eq) = find_ws_seq(s, &["<", tag, "="])?;
|
||||
let gt_offset = s[after_eq..].find('>')?;
|
||||
let name = s[after_eq..after_eq + gt_offset].trim();
|
||||
let body_start = after_eq + gt_offset + 1;
|
||||
|
||||
// Close tag: exact match — model doesn't insert whitespace in close tags
|
||||
let close = format!("</{}>", tag);
|
||||
let close_offset = s[body_start..].find(&close)?;
|
||||
let body = &s[body_start..body_start + close_offset];
|
||||
// Strip the single leading/trailing newline from XML formatting,
|
||||
// but preserve all other whitespace (indentation matters for code).
|
||||
let body = body.strip_prefix('\n').unwrap_or(body);
|
||||
let body = body.strip_suffix('\n').unwrap_or(body);
|
||||
let rest = &s[body_start + close_offset + close.len()..];
|
||||
|
||||
let start = s.find(&open)? + open.len();
|
||||
let name_end = start + s[start..].find('>')?;
|
||||
let body_start = name_end + 1;
|
||||
let body_end = body_start + s[body_start..].find(&close)?;
|
||||
|
||||
Some((
|
||||
s[start..name_end].trim(),
|
||||
s[body_start..body_end].trim(),
|
||||
&s[body_end + close.len()..],
|
||||
))
|
||||
Some((name, body, rest))
|
||||
}
|
||||
|
||||
fn parse_tool_call_body(body: &str) -> Option<(String, String)> {
|
||||
let normalized = normalize_xml_tags(body);
|
||||
let body = normalized.trim();
|
||||
let body = body.trim();
|
||||
parse_xml_tool_call(body)
|
||||
.or_else(|| parse_json_tool_call(body))
|
||||
}
|
||||
|
|
@ -509,6 +519,38 @@ fn parse_json_tool_call(body: &str) -> Option<(String, String)> {
|
|||
Some((name.to_string(), serde_json::to_string(arguments).unwrap_or_default()))
|
||||
}
|
||||
|
||||
/// Search `buf` for `close_tag`. If found, append everything before it to
/// `accum`, advance `buf` past the tag, and return the accumulated content.
/// If not found, drain the safe prefix (preserving any partial tag match at
/// the end of buf) into `accum`.
fn scan_close_tag(buf: &mut String, close_tag: &str, accum: &mut String) -> Option<String> {
    if let Some(pos) = buf.find(close_tag) {
        accum.push_str(&buf[..pos]);
        *buf = buf[pos + close_tag.len()..].to_string();
        // Hand the caller the full accumulated content, resetting `accum`.
        Some(std::mem::take(accum))
    } else {
        let drained = drain_safe(buf, close_tag.len());
        if !drained.is_empty() {
            accum.push_str(&drained);
        }
        None
    }
}

/// Remove everything from `buf` except the last `tag_len` bytes, which might
/// be a partial tag. Returns the removed prefix.
fn drain_safe(buf: &mut String, tag_len: usize) -> String {
    let mut safe = buf.len().saturating_sub(tag_len);
    if safe == 0 {
        return String::new();
    }
    // Stable stand-in for the unstable (nightly-only) `str::floor_char_boundary`:
    // back up to the nearest char boundary at or below `safe` so slicing
    // cannot panic mid-character.
    while !buf.is_char_boundary(safe) {
        safe -= 1;
    }
    // Split once: `tail` keeps its allocation, the old head is returned.
    let tail = buf.split_off(safe);
    std::mem::replace(buf, tail)
}
|
||||
|
||||
impl ResponseParser {
|
||||
pub fn new(branch_idx: usize) -> Self {
|
||||
Self {
|
||||
|
|
@ -544,10 +586,11 @@ impl ResponseParser {
|
|||
let mut full_text = String::new();
|
||||
while let Some(event) = stream.recv().await {
|
||||
match event {
|
||||
super::api::StreamToken::Token { text, id } => {
|
||||
super::api::StreamToken::Token(id) => {
|
||||
let text = super::tokenizer::decode(&[id]);
|
||||
full_text.push_str(&text);
|
||||
let mut ctx = agent.context.lock().await;
|
||||
let calls = parser.feed_token(&text, id, &mut ctx);
|
||||
let calls = parser.feed_token(&text, &mut ctx);
|
||||
if !calls.is_empty() {
|
||||
if let Some(ref mut f) = log_file {
|
||||
use std::io::Write;
|
||||
|
|
@ -596,97 +639,72 @@ impl ResponseParser {
|
|||
(rx, handle)
|
||||
}
|
||||
|
||||
pub fn feed_token(&mut self, text: &str, _token_id: u32, ctx: &mut ContextState) -> Vec<PendingToolCall> {
|
||||
pub fn feed_token(&mut self, text: &str, ctx: &mut ContextState) -> Vec<PendingToolCall> {
|
||||
const THINK_OPEN: &str = "<think>";
|
||||
const THINK_CLOSE: &str = "</think>";
|
||||
const TOOL_CALL_OPEN: &str = "<tool_call>";
|
||||
const TOOL_CALL_CLOSE: &str = "</tool_call>";
|
||||
const OPEN_TAGS: &[&str] = &[THINK_OPEN, TOOL_CALL_OPEN];
|
||||
|
||||
let mut pending = Vec::new();
|
||||
self.buf.push_str(text);
|
||||
|
||||
loop {
|
||||
if self.in_think {
|
||||
match self.buf.find("</think>") {
|
||||
Some(end) => {
|
||||
self.think_buf.push_str(&self.buf[..end]);
|
||||
self.buf = self.buf[end + 8..].to_string();
|
||||
self.in_think = false;
|
||||
let text = std::mem::take(&mut self.think_buf).trim().to_string();
|
||||
if !text.is_empty() {
|
||||
self.push_child(ctx, AstNode::thinking(text));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
None => {
|
||||
let safe = self.buf.len().saturating_sub(8);
|
||||
if safe > 0 {
|
||||
let safe = self.buf.floor_char_boundary(safe);
|
||||
self.think_buf.push_str(&self.buf[..safe]);
|
||||
self.buf = self.buf[safe..].to_string();
|
||||
}
|
||||
break;
|
||||
if let Some(content) = scan_close_tag(&mut self.buf, THINK_CLOSE, &mut self.think_buf) {
|
||||
self.in_think = false;
|
||||
let text = content.trim().to_string();
|
||||
if !text.is_empty() {
|
||||
self.push_child(ctx, AstNode::thinking(text));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if self.in_tool_call {
|
||||
match self.buf.find("</tool_call>") {
|
||||
Some(end) => {
|
||||
self.tool_call_buf.push_str(&self.buf[..end]);
|
||||
self.buf = self.buf[end + 12..].to_string();
|
||||
self.in_tool_call = false;
|
||||
if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) {
|
||||
self.flush_content(ctx);
|
||||
self.push_child(ctx, AstNode::tool_call(&name, &args));
|
||||
self.call_counter += 1;
|
||||
pending.push(PendingToolCall {
|
||||
name,
|
||||
arguments: args,
|
||||
id: format!("call_{}", self.call_counter),
|
||||
});
|
||||
}
|
||||
self.tool_call_buf.clear();
|
||||
continue;
|
||||
}
|
||||
None => {
|
||||
let safe = self.buf.len().saturating_sub(12);
|
||||
if safe > 0 {
|
||||
let safe = self.buf.floor_char_boundary(safe);
|
||||
self.tool_call_buf.push_str(&self.buf[..safe]);
|
||||
self.buf = self.buf[safe..].to_string();
|
||||
}
|
||||
break;
|
||||
if let Some(content) = scan_close_tag(&mut self.buf, TOOL_CALL_CLOSE, &mut self.tool_call_buf) {
|
||||
self.in_tool_call = false;
|
||||
if let Some((name, args)) = parse_tool_call_body(&content) {
|
||||
self.flush_content(ctx);
|
||||
self.push_child(ctx, AstNode::tool_call(&name, &args));
|
||||
self.call_counter += 1;
|
||||
pending.push(PendingToolCall {
|
||||
name,
|
||||
arguments: args,
|
||||
id: format!("call_{}", self.call_counter),
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
let think_pos = self.buf.find("<think>");
|
||||
let tool_pos = self.buf.find("<tool_call>");
|
||||
let next_tag = match (think_pos, tool_pos) {
|
||||
(Some(a), Some(b)) => Some(a.min(b)),
|
||||
(Some(a), None) => Some(a),
|
||||
(None, Some(b)) => Some(b),
|
||||
(None, None) => None,
|
||||
};
|
||||
// Not inside a tag — find the earliest opening tag
|
||||
let next = OPEN_TAGS.iter()
|
||||
.filter_map(|tag| self.buf.find(tag).map(|pos| (pos, *tag)))
|
||||
.min_by_key(|(pos, _)| *pos);
|
||||
|
||||
match next_tag {
|
||||
Some(pos) => {
|
||||
match next {
|
||||
Some((pos, tag)) => {
|
||||
if pos > 0 {
|
||||
self.content_parts.push(self.buf[..pos].to_string());
|
||||
}
|
||||
if self.buf[pos..].starts_with("<think>") {
|
||||
self.buf = self.buf[pos + 7..].to_string();
|
||||
self.flush_content(ctx);
|
||||
self.in_think = true;
|
||||
} else {
|
||||
self.buf = self.buf[pos + 11..].to_string();
|
||||
self.flush_content(ctx);
|
||||
self.in_tool_call = true;
|
||||
self.buf = self.buf[pos + tag.len()..].to_string();
|
||||
self.flush_content(ctx);
|
||||
match tag {
|
||||
THINK_OPEN => self.in_think = true,
|
||||
TOOL_CALL_OPEN => self.in_tool_call = true,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
continue;
|
||||
}
|
||||
None => {
|
||||
let safe = self.buf.len().saturating_sub(11);
|
||||
if safe > 0 {
|
||||
let safe = self.buf.floor_char_boundary(safe);
|
||||
self.content_parts.push(self.buf[..safe].to_string());
|
||||
self.buf = self.buf[safe..].to_string();
|
||||
// Keep a tail that might be a partial opening tag
|
||||
let max_tag = OPEN_TAGS.iter().map(|t| t.len()).max().unwrap();
|
||||
let drained = drain_safe(&mut self.buf, max_tag);
|
||||
if !drained.is_empty() {
|
||||
self.content_parts.push(drained);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
@ -1008,7 +1026,9 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_tool_call_xml_parse_streamed_whitespace() {
|
||||
let body = "<\nfunction\n=\nbash\n>\n<\nparameter\n=\ncommand\n>pwd</\nparameter\n>\n</\nfunction\n>";
|
||||
// Streaming tokenization can insert whitespace in opening tags,
|
||||
// but close tags are always emitted verbatim.
|
||||
let body = "<\nfunction\n=\nbash\n>\n<\nparameter\n=\ncommand\n>pwd</parameter>\n</function>";
|
||||
let (name, args) = parse_tool_call_body(body).unwrap();
|
||||
assert_eq!(name, "bash");
|
||||
let args: serde_json::Value = serde_json::from_str(&args).unwrap();
|
||||
|
|
@ -1025,15 +1045,12 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
fn test_normalize_preserves_content() {
    // A well-formed tool-call block must round-trip through normalization
    // completely unchanged.
    let input = "<function=bash>\n<parameter=command>echo hello world</parameter>\n</function>";
    assert_eq!(normalize_xml_tags(input), input);
}
|
||||
|
||||
#[test]
fn test_tool_call_preserves_code_with_angle_brackets() {
    // Code arguments legitimately contain `<` / `>`; parsing must not strip
    // or normalize the whitespace around them.
    let body = "<function=edit>\n<parameter=code>if x < y {\n std::mem::swap(&mut a, &mut b);\n}</parameter>\n</function>";
    let (name, args) = parse_tool_call_body(body).unwrap();
    assert_eq!(name, "edit");
    let args: serde_json::Value = serde_json::from_str(&args).unwrap();
    assert_eq!(args["code"], "if x < y {\n std::mem::swap(&mut a, &mut b);\n}");
}
|
||||
|
||||
// -- ResponseParser tests -------------------------------------------------
|
||||
|
|
@ -1047,7 +1064,7 @@ mod tests {
|
|||
let mut calls = Vec::new();
|
||||
for chunk in chunks {
|
||||
// Feed each chunk as a single token (id=0 for tests)
|
||||
calls.extend(p.feed_token(chunk, 0, &mut ctx));
|
||||
calls.extend(p.feed_token(chunk, &mut ctx));
|
||||
}
|
||||
p.finish(&mut ctx);
|
||||
(ctx, calls)
|
||||
|
|
@ -1109,7 +1126,7 @@ mod tests {
|
|||
ctx.push_no_log(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
|
||||
let mut p = ResponseParser::new(0);
|
||||
for ch in text.chars() {
|
||||
p.feed_token(&ch.to_string(), 0, &mut ctx);
|
||||
p.feed_token(&ch.to_string(), &mut ctx);
|
||||
}
|
||||
p.finish(&mut ctx);
|
||||
let b = bodies(assistant_children(&ctx));
|
||||
|
|
@ -1126,7 +1143,7 @@ mod tests {
|
|||
let mut p = ResponseParser::new(0);
|
||||
let mut tool_calls = 0;
|
||||
for ch in text.chars() {
|
||||
tool_calls += p.feed_token(&ch.to_string(), 0, &mut ctx).len();
|
||||
tool_calls += p.feed_token(&ch.to_string(), &mut ctx).len();
|
||||
}
|
||||
p.finish(&mut ctx);
|
||||
assert_eq!(tool_calls, 1);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue