Parsing fixes
Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
This commit is contained in:
parent
b55230ce3f
commit
0af97774f4
2 changed files with 143 additions and 127 deletions
|
|
@ -45,7 +45,7 @@ pub(crate) struct SamplingParams {
|
||||||
|
|
||||||
/// One token from the streaming completions API.
|
/// One token from the streaming completions API.
|
||||||
pub enum StreamToken {
|
pub enum StreamToken {
|
||||||
Token { text: String, id: u32 },
|
Token(u32),
|
||||||
Done { usage: Option<Usage> },
|
Done { usage: Option<Usage> },
|
||||||
Error(String),
|
Error(String),
|
||||||
}
|
}
|
||||||
|
|
@ -159,20 +159,19 @@ async fn stream_completions(
|
||||||
};
|
};
|
||||||
|
|
||||||
for choice in choices {
|
for choice in choices {
|
||||||
let text = choice["text"].as_str().unwrap_or("");
|
if let Some(ids) = choice["token_ids"].as_array() {
|
||||||
let token_ids = choice["token_ids"].as_array();
|
for id_val in ids {
|
||||||
|
|
||||||
if let Some(ids) = token_ids {
|
|
||||||
for (i, id_val) in ids.iter().enumerate() {
|
|
||||||
if let Some(id) = id_val.as_u64() {
|
if let Some(id) = id_val.as_u64() {
|
||||||
let _ = tx.send(StreamToken::Token {
|
let _ = tx.send(StreamToken::Token(id as u32));
|
||||||
text: if i == 0 { text.to_string() } else { String::new() },
|
}
|
||||||
id: id as u32,
|
}
|
||||||
});
|
} else if let Some(text) = choice["text"].as_str() {
|
||||||
|
// Fallback: provider didn't return token_ids, encode locally
|
||||||
|
if !text.is_empty() {
|
||||||
|
for id in super::tokenizer::encode(text) {
|
||||||
|
let _ = tx.send(StreamToken::Token(id));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if !text.is_empty() {
|
|
||||||
let _ = tx.send(StreamToken::Token { text: text.to_string(), id: 0 });
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -444,47 +444,57 @@ fn format_tool_call_xml(name: &str, args_json: &str) -> String {
|
||||||
xml
|
xml
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_xml_tags(text: &str) -> String {
|
/// Search for a sequence of literal parts separated by optional ASCII whitespace.
|
||||||
let mut result = String::with_capacity(text.len());
|
/// Returns (start, end) byte positions of the overall match.
|
||||||
let mut chars = text.chars().peekable();
|
///
|
||||||
let mut in_tag = false;
|
/// Handles the case where streaming tokenization inserts whitespace inside
|
||||||
|
/// XML tag structure, e.g. `< function = bash >` instead of `<function=bash>`.
|
||||||
while let Some(ch) = chars.next() {
|
fn find_ws_seq(s: &str, parts: &[&str]) -> Option<(usize, usize)> {
|
||||||
if ch == '<' {
|
let bytes = s.as_bytes();
|
||||||
in_tag = true;
|
let mut search_from = 0;
|
||||||
result.push(ch);
|
'outer: loop {
|
||||||
} else if ch == '>' {
|
let start = s[search_from..].find(parts[0])? + search_from;
|
||||||
result.push(ch);
|
let mut pos = start + parts[0].len();
|
||||||
in_tag = false;
|
for &part in &parts[1..] {
|
||||||
} else if in_tag && ch.is_whitespace() {
|
while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
|
||||||
// Skip whitespace inside tags (between < and >)
|
pos += 1;
|
||||||
continue;
|
|
||||||
} else {
|
|
||||||
result.push(ch);
|
|
||||||
}
|
}
|
||||||
|
if !s[pos..].starts_with(part) {
|
||||||
|
search_from = start + 1;
|
||||||
|
continue 'outer;
|
||||||
|
}
|
||||||
|
pos += part.len();
|
||||||
|
}
|
||||||
|
return Some((start, pos));
|
||||||
}
|
}
|
||||||
result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parse a Qwen-style XML tag: `<tag=name>body</tag>`.
|
||||||
|
/// Tolerates whitespace inside tag delimiters (streaming artifact).
|
||||||
|
/// Body content is returned verbatim except for a single leading/trailing
|
||||||
|
/// newline (XML formatting convention).
|
||||||
fn parse_qwen_tag<'a>(s: &'a str, tag: &str) -> Option<(&'a str, &'a str, &'a str)> {
|
fn parse_qwen_tag<'a>(s: &'a str, tag: &str) -> Option<(&'a str, &'a str, &'a str)> {
|
||||||
let open = format!("<{}=", tag);
|
// Open tag: tolerate whitespace from streaming tokenization
|
||||||
|
let (_, after_eq) = find_ws_seq(s, &["<", tag, "="])?;
|
||||||
|
let gt_offset = s[after_eq..].find('>')?;
|
||||||
|
let name = s[after_eq..after_eq + gt_offset].trim();
|
||||||
|
let body_start = after_eq + gt_offset + 1;
|
||||||
|
|
||||||
|
// Close tag: exact match — model doesn't insert whitespace in close tags
|
||||||
let close = format!("</{}>", tag);
|
let close = format!("</{}>", tag);
|
||||||
|
let close_offset = s[body_start..].find(&close)?;
|
||||||
|
let body = &s[body_start..body_start + close_offset];
|
||||||
|
// Strip the single leading/trailing newline from XML formatting,
|
||||||
|
// but preserve all other whitespace (indentation matters for code).
|
||||||
|
let body = body.strip_prefix('\n').unwrap_or(body);
|
||||||
|
let body = body.strip_suffix('\n').unwrap_or(body);
|
||||||
|
let rest = &s[body_start + close_offset + close.len()..];
|
||||||
|
|
||||||
let start = s.find(&open)? + open.len();
|
Some((name, body, rest))
|
||||||
let name_end = start + s[start..].find('>')?;
|
|
||||||
let body_start = name_end + 1;
|
|
||||||
let body_end = body_start + s[body_start..].find(&close)?;
|
|
||||||
|
|
||||||
Some((
|
|
||||||
s[start..name_end].trim(),
|
|
||||||
s[body_start..body_end].trim(),
|
|
||||||
&s[body_end + close.len()..],
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_tool_call_body(body: &str) -> Option<(String, String)> {
|
fn parse_tool_call_body(body: &str) -> Option<(String, String)> {
|
||||||
let normalized = normalize_xml_tags(body);
|
let body = body.trim();
|
||||||
let body = normalized.trim();
|
|
||||||
parse_xml_tool_call(body)
|
parse_xml_tool_call(body)
|
||||||
.or_else(|| parse_json_tool_call(body))
|
.or_else(|| parse_json_tool_call(body))
|
||||||
}
|
}
|
||||||
|
|
@ -509,6 +519,38 @@ fn parse_json_tool_call(body: &str) -> Option<(String, String)> {
|
||||||
Some((name.to_string(), serde_json::to_string(arguments).unwrap_or_default()))
|
Some((name.to_string(), serde_json::to_string(arguments).unwrap_or_default()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Search `buf` for `close_tag`. If found, append everything before it to
|
||||||
|
/// `accum`, advance `buf` past the tag, and return the accumulated content.
|
||||||
|
/// If not found, drain the safe prefix (preserving any partial tag match at
|
||||||
|
/// the end of buf) into `accum`.
|
||||||
|
fn scan_close_tag(buf: &mut String, close_tag: &str, accum: &mut String) -> Option<String> {
|
||||||
|
if let Some(pos) = buf.find(close_tag) {
|
||||||
|
accum.push_str(&buf[..pos]);
|
||||||
|
*buf = buf[pos + close_tag.len()..].to_string();
|
||||||
|
Some(std::mem::take(accum))
|
||||||
|
} else {
|
||||||
|
let drained = drain_safe(buf, close_tag.len());
|
||||||
|
if !drained.is_empty() {
|
||||||
|
accum.push_str(&drained);
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove everything from `buf` except the last `tag_len` bytes, which might
|
||||||
|
/// be a partial tag. Returns the removed prefix.
|
||||||
|
fn drain_safe(buf: &mut String, tag_len: usize) -> String {
|
||||||
|
let safe = buf.len().saturating_sub(tag_len);
|
||||||
|
if safe > 0 {
|
||||||
|
let safe = buf.floor_char_boundary(safe);
|
||||||
|
let drained = buf[..safe].to_string();
|
||||||
|
*buf = buf[safe..].to_string();
|
||||||
|
drained
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ResponseParser {
|
impl ResponseParser {
|
||||||
pub fn new(branch_idx: usize) -> Self {
|
pub fn new(branch_idx: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|
@ -544,10 +586,11 @@ impl ResponseParser {
|
||||||
let mut full_text = String::new();
|
let mut full_text = String::new();
|
||||||
while let Some(event) = stream.recv().await {
|
while let Some(event) = stream.recv().await {
|
||||||
match event {
|
match event {
|
||||||
super::api::StreamToken::Token { text, id } => {
|
super::api::StreamToken::Token(id) => {
|
||||||
|
let text = super::tokenizer::decode(&[id]);
|
||||||
full_text.push_str(&text);
|
full_text.push_str(&text);
|
||||||
let mut ctx = agent.context.lock().await;
|
let mut ctx = agent.context.lock().await;
|
||||||
let calls = parser.feed_token(&text, id, &mut ctx);
|
let calls = parser.feed_token(&text, &mut ctx);
|
||||||
if !calls.is_empty() {
|
if !calls.is_empty() {
|
||||||
if let Some(ref mut f) = log_file {
|
if let Some(ref mut f) = log_file {
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
@ -596,42 +639,33 @@ impl ResponseParser {
|
||||||
(rx, handle)
|
(rx, handle)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn feed_token(&mut self, text: &str, _token_id: u32, ctx: &mut ContextState) -> Vec<PendingToolCall> {
|
pub fn feed_token(&mut self, text: &str, ctx: &mut ContextState) -> Vec<PendingToolCall> {
|
||||||
|
const THINK_OPEN: &str = "<think>";
|
||||||
|
const THINK_CLOSE: &str = "</think>";
|
||||||
|
const TOOL_CALL_OPEN: &str = "<tool_call>";
|
||||||
|
const TOOL_CALL_CLOSE: &str = "</tool_call>";
|
||||||
|
const OPEN_TAGS: &[&str] = &[THINK_OPEN, TOOL_CALL_OPEN];
|
||||||
|
|
||||||
let mut pending = Vec::new();
|
let mut pending = Vec::new();
|
||||||
self.buf.push_str(text);
|
self.buf.push_str(text);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
if self.in_think {
|
if self.in_think {
|
||||||
match self.buf.find("</think>") {
|
if let Some(content) = scan_close_tag(&mut self.buf, THINK_CLOSE, &mut self.think_buf) {
|
||||||
Some(end) => {
|
|
||||||
self.think_buf.push_str(&self.buf[..end]);
|
|
||||||
self.buf = self.buf[end + 8..].to_string();
|
|
||||||
self.in_think = false;
|
self.in_think = false;
|
||||||
let text = std::mem::take(&mut self.think_buf).trim().to_string();
|
let text = content.trim().to_string();
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
self.push_child(ctx, AstNode::thinking(text));
|
self.push_child(ctx, AstNode::thinking(text));
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
None => {
|
|
||||||
let safe = self.buf.len().saturating_sub(8);
|
|
||||||
if safe > 0 {
|
|
||||||
let safe = self.buf.floor_char_boundary(safe);
|
|
||||||
self.think_buf.push_str(&self.buf[..safe]);
|
|
||||||
self.buf = self.buf[safe..].to_string();
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.in_tool_call {
|
if self.in_tool_call {
|
||||||
match self.buf.find("</tool_call>") {
|
if let Some(content) = scan_close_tag(&mut self.buf, TOOL_CALL_CLOSE, &mut self.tool_call_buf) {
|
||||||
Some(end) => {
|
|
||||||
self.tool_call_buf.push_str(&self.buf[..end]);
|
|
||||||
self.buf = self.buf[end + 12..].to_string();
|
|
||||||
self.in_tool_call = false;
|
self.in_tool_call = false;
|
||||||
if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) {
|
if let Some((name, args)) = parse_tool_call_body(&content) {
|
||||||
self.flush_content(ctx);
|
self.flush_content(ctx);
|
||||||
self.push_child(ctx, AstNode::tool_call(&name, &args));
|
self.push_child(ctx, AstNode::tool_call(&name, &args));
|
||||||
self.call_counter += 1;
|
self.call_counter += 1;
|
||||||
|
|
@ -641,52 +675,36 @@ impl ResponseParser {
|
||||||
id: format!("call_{}", self.call_counter),
|
id: format!("call_{}", self.call_counter),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
self.tool_call_buf.clear();
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
None => {
|
|
||||||
let safe = self.buf.len().saturating_sub(12);
|
|
||||||
if safe > 0 {
|
|
||||||
let safe = self.buf.floor_char_boundary(safe);
|
|
||||||
self.tool_call_buf.push_str(&self.buf[..safe]);
|
|
||||||
self.buf = self.buf[safe..].to_string();
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let think_pos = self.buf.find("<think>");
|
// Not inside a tag — find the earliest opening tag
|
||||||
let tool_pos = self.buf.find("<tool_call>");
|
let next = OPEN_TAGS.iter()
|
||||||
let next_tag = match (think_pos, tool_pos) {
|
.filter_map(|tag| self.buf.find(tag).map(|pos| (pos, *tag)))
|
||||||
(Some(a), Some(b)) => Some(a.min(b)),
|
.min_by_key(|(pos, _)| *pos);
|
||||||
(Some(a), None) => Some(a),
|
|
||||||
(None, Some(b)) => Some(b),
|
|
||||||
(None, None) => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
match next_tag {
|
match next {
|
||||||
Some(pos) => {
|
Some((pos, tag)) => {
|
||||||
if pos > 0 {
|
if pos > 0 {
|
||||||
self.content_parts.push(self.buf[..pos].to_string());
|
self.content_parts.push(self.buf[..pos].to_string());
|
||||||
}
|
}
|
||||||
if self.buf[pos..].starts_with("<think>") {
|
self.buf = self.buf[pos + tag.len()..].to_string();
|
||||||
self.buf = self.buf[pos + 7..].to_string();
|
|
||||||
self.flush_content(ctx);
|
self.flush_content(ctx);
|
||||||
self.in_think = true;
|
match tag {
|
||||||
} else {
|
THINK_OPEN => self.in_think = true,
|
||||||
self.buf = self.buf[pos + 11..].to_string();
|
TOOL_CALL_OPEN => self.in_tool_call = true,
|
||||||
self.flush_content(ctx);
|
_ => unreachable!(),
|
||||||
self.in_tool_call = true;
|
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
let safe = self.buf.len().saturating_sub(11);
|
// Keep a tail that might be a partial opening tag
|
||||||
if safe > 0 {
|
let max_tag = OPEN_TAGS.iter().map(|t| t.len()).max().unwrap();
|
||||||
let safe = self.buf.floor_char_boundary(safe);
|
let drained = drain_safe(&mut self.buf, max_tag);
|
||||||
self.content_parts.push(self.buf[..safe].to_string());
|
if !drained.is_empty() {
|
||||||
self.buf = self.buf[safe..].to_string();
|
self.content_parts.push(drained);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -1008,7 +1026,9 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_tool_call_xml_parse_streamed_whitespace() {
|
fn test_tool_call_xml_parse_streamed_whitespace() {
|
||||||
let body = "<\nfunction\n=\nbash\n>\n<\nparameter\n=\ncommand\n>pwd</\nparameter\n>\n</\nfunction\n>";
|
// Streaming tokenization can insert whitespace in opening tags,
|
||||||
|
// but close tags are always emitted verbatim.
|
||||||
|
let body = "<\nfunction\n=\nbash\n>\n<\nparameter\n=\ncommand\n>pwd</parameter>\n</function>";
|
||||||
let (name, args) = parse_tool_call_body(body).unwrap();
|
let (name, args) = parse_tool_call_body(body).unwrap();
|
||||||
assert_eq!(name, "bash");
|
assert_eq!(name, "bash");
|
||||||
let args: serde_json::Value = serde_json::from_str(&args).unwrap();
|
let args: serde_json::Value = serde_json::from_str(&args).unwrap();
|
||||||
|
|
@ -1025,15 +1045,12 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_normalize_preserves_content() {
|
fn test_tool_call_preserves_code_with_angle_brackets() {
|
||||||
let text = "<function=bash>\n<parameter=command>echo hello world</parameter>\n</function>";
|
let body = "<function=edit>\n<parameter=code>if x < y {\n std::mem::swap(&mut a, &mut b);\n}</parameter>\n</function>";
|
||||||
let normalized = normalize_xml_tags(text);
|
let (name, args) = parse_tool_call_body(body).unwrap();
|
||||||
assert_eq!(normalized, text);
|
assert_eq!(name, "edit");
|
||||||
}
|
let args: serde_json::Value = serde_json::from_str(&args).unwrap();
|
||||||
|
assert_eq!(args["code"], "if x < y {\n std::mem::swap(&mut a, &mut b);\n}");
|
||||||
#[test]
|
|
||||||
fn test_normalize_strips_tag_internal_whitespace() {
|
|
||||||
assert_eq!(normalize_xml_tags("<\nfunction\n=\nbash\n>"), "<function=bash>");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// -- ResponseParser tests -------------------------------------------------
|
// -- ResponseParser tests -------------------------------------------------
|
||||||
|
|
@ -1047,7 +1064,7 @@ mod tests {
|
||||||
let mut calls = Vec::new();
|
let mut calls = Vec::new();
|
||||||
for chunk in chunks {
|
for chunk in chunks {
|
||||||
// Feed each chunk as a single token (id=0 for tests)
|
// Feed each chunk as a single token (id=0 for tests)
|
||||||
calls.extend(p.feed_token(chunk, 0, &mut ctx));
|
calls.extend(p.feed_token(chunk, &mut ctx));
|
||||||
}
|
}
|
||||||
p.finish(&mut ctx);
|
p.finish(&mut ctx);
|
||||||
(ctx, calls)
|
(ctx, calls)
|
||||||
|
|
@ -1109,7 +1126,7 @@ mod tests {
|
||||||
ctx.push_no_log(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
|
ctx.push_no_log(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
|
||||||
let mut p = ResponseParser::new(0);
|
let mut p = ResponseParser::new(0);
|
||||||
for ch in text.chars() {
|
for ch in text.chars() {
|
||||||
p.feed_token(&ch.to_string(), 0, &mut ctx);
|
p.feed_token(&ch.to_string(), &mut ctx);
|
||||||
}
|
}
|
||||||
p.finish(&mut ctx);
|
p.finish(&mut ctx);
|
||||||
let b = bodies(assistant_children(&ctx));
|
let b = bodies(assistant_children(&ctx));
|
||||||
|
|
@ -1126,7 +1143,7 @@ mod tests {
|
||||||
let mut p = ResponseParser::new(0);
|
let mut p = ResponseParser::new(0);
|
||||||
let mut tool_calls = 0;
|
let mut tool_calls = 0;
|
||||||
for ch in text.chars() {
|
for ch in text.chars() {
|
||||||
tool_calls += p.feed_token(&ch.to_string(), 0, &mut ctx).len();
|
tool_calls += p.feed_token(&ch.to_string(), &mut ctx).len();
|
||||||
}
|
}
|
||||||
p.finish(&mut ctx);
|
p.finish(&mut ctx);
|
||||||
assert_eq!(tool_calls, 1);
|
assert_eq!(tool_calls, 1);
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue