Fix parser: re-encode tokens instead of tracking model IDs through tag splits

The parser can't reliably split model-produced token IDs at tag
boundaries (<think>, <tool_call>) because BPE tokens can span across
tags. Instead, each leaf's token IDs are recomputed by re-encoding
its text content with the local tokenizer. This gives clean token
boundaries aligned with semantic structure — better for token
budgeting and potentially for the model during fine-tuning.

Also skip serializing token_ids to the conversation log (they're
cached state, recomputed on construction).

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-08 17:08:42 -04:00
parent 88ac5e10ce
commit 5ec2ff95d8

View file

@ -66,6 +66,7 @@ pub enum NodeBody {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeLeaf { pub struct NodeLeaf {
body: NodeBody, body: NodeBody,
#[serde(skip)]
token_ids: Vec<u32>, token_ids: Vec<u32>,
timestamp: Option<DateTime<Utc>>, timestamp: Option<DateTime<Utc>>,
} }
@ -115,15 +116,11 @@ pub struct ResponseParser {
branch_idx: usize, branch_idx: usize,
call_counter: u32, call_counter: u32,
buf: String, buf: String,
buf_token_ids: Vec<u32>,
content_parts: Vec<String>, content_parts: Vec<String>,
content_token_ids: Vec<u32>,
in_think: bool, in_think: bool,
think_buf: String, think_buf: String,
think_token_ids: Vec<u32>,
in_tool_call: bool, in_tool_call: bool,
tool_call_buf: String, tool_call_buf: String,
tool_call_token_ids: Vec<u32>,
} }
impl Role { impl Role {
@ -471,15 +468,11 @@ impl ResponseParser {
branch_idx, branch_idx,
call_counter: 0, call_counter: 0,
buf: String::new(), buf: String::new(),
buf_token_ids: Vec::new(),
content_parts: Vec::new(), content_parts: Vec::new(),
content_token_ids: Vec::new(),
in_think: false, in_think: false,
think_buf: String::new(), think_buf: String::new(),
think_token_ids: Vec::new(),
in_tool_call: false, in_tool_call: false,
tool_call_buf: String::new(), tool_call_buf: String::new(),
tool_call_token_ids: Vec::new(),
} }
} }
@ -523,23 +516,19 @@ impl ResponseParser {
(rx, handle) (rx, handle)
} }
pub fn feed_token(&mut self, text: &str, token_id: u32, ctx: &mut ContextState) -> Vec<PendingToolCall> { pub fn feed_token(&mut self, text: &str, _token_id: u32, ctx: &mut ContextState) -> Vec<PendingToolCall> {
let mut pending = Vec::new(); let mut pending = Vec::new();
self.buf.push_str(text); self.buf.push_str(text);
self.buf_token_ids.push(token_id);
loop { loop {
if self.in_think { if self.in_think {
match self.buf.find("</think>") { match self.buf.find("</think>") {
Some(end) => { Some(end) => {
self.think_buf.push_str(&self.buf[..end]); self.think_buf.push_str(&self.buf[..end]);
// Token IDs: move all buffered IDs to think (approximate split)
self.think_token_ids.extend(self.buf_token_ids.drain(..));
self.buf = self.buf[end + 8..].to_string(); self.buf = self.buf[end + 8..].to_string();
self.in_think = false; self.in_think = false;
let text = std::mem::take(&mut self.think_buf); let text = std::mem::take(&mut self.think_buf);
let ids = std::mem::take(&mut self.think_token_ids); self.push_child(ctx, AstNode::thinking(text));
self.push_child_with_tokens(ctx, NodeBody::Thinking(text), ids);
continue; continue;
} }
None => { None => {
@ -548,7 +537,6 @@ impl ResponseParser {
let safe = self.buf.floor_char_boundary(safe); let safe = self.buf.floor_char_boundary(safe);
self.think_buf.push_str(&self.buf[..safe]); self.think_buf.push_str(&self.buf[..safe]);
self.buf = self.buf[safe..].to_string(); self.buf = self.buf[safe..].to_string();
// Keep token IDs in buf (lookahead)
} }
break; break;
} }
@ -559,12 +547,10 @@ impl ResponseParser {
match self.buf.find("</tool_call>") { match self.buf.find("</tool_call>") {
Some(end) => { Some(end) => {
self.tool_call_buf.push_str(&self.buf[..end]); self.tool_call_buf.push_str(&self.buf[..end]);
self.tool_call_token_ids.extend(self.buf_token_ids.drain(..));
self.buf = self.buf[end + 12..].to_string(); self.buf = self.buf[end + 12..].to_string();
self.in_tool_call = false; self.in_tool_call = false;
if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) { if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) {
self.flush_content(ctx); self.flush_content(ctx);
// Tool calls get re-tokenized from structured data
self.push_child(ctx, AstNode::tool_call(&name, &args)); self.push_child(ctx, AstNode::tool_call(&name, &args));
self.call_counter += 1; self.call_counter += 1;
pending.push(PendingToolCall { pending.push(PendingToolCall {
@ -574,7 +560,6 @@ impl ResponseParser {
}); });
} }
self.tool_call_buf.clear(); self.tool_call_buf.clear();
self.tool_call_token_ids.clear();
continue; continue;
} }
None => { None => {
@ -603,8 +588,6 @@ impl ResponseParser {
if pos > 0 { if pos > 0 {
self.content_parts.push(self.buf[..pos].to_string()); self.content_parts.push(self.buf[..pos].to_string());
} }
// Move token IDs to content accumulator
self.content_token_ids.extend(self.buf_token_ids.drain(..));
if self.buf[pos..].starts_with("<think>") { if self.buf[pos..].starts_with("<think>") {
self.buf = self.buf[pos + 7..].to_string(); self.buf = self.buf[pos + 7..].to_string();
self.flush_content(ctx); self.flush_content(ctx);
@ -622,7 +605,6 @@ impl ResponseParser {
let safe = self.buf.floor_char_boundary(safe); let safe = self.buf.floor_char_boundary(safe);
self.content_parts.push(self.buf[..safe].to_string()); self.content_parts.push(self.buf[..safe].to_string());
self.buf = self.buf[safe..].to_string(); self.buf = self.buf[safe..].to_string();
// Keep token IDs in buf (lookahead)
} }
break; break;
} }
@ -636,17 +618,11 @@ impl ResponseParser {
ctx.push_child(Section::Conversation, self.branch_idx, child); ctx.push_child(Section::Conversation, self.branch_idx, child);
} }
fn push_child_with_tokens(&self, ctx: &mut ContextState, body: NodeBody, token_ids: Vec<u32>) {
let leaf = NodeLeaf { body, token_ids, timestamp: None };
ctx.push_child(Section::Conversation, self.branch_idx, AstNode::Leaf(leaf));
}
fn flush_content(&mut self, ctx: &mut ContextState) { fn flush_content(&mut self, ctx: &mut ContextState) {
if !self.content_parts.is_empty() { if !self.content_parts.is_empty() {
let text: String = self.content_parts.drain(..).collect(); let text: String = self.content_parts.drain(..).collect();
if !text.is_empty() { if !text.is_empty() {
let token_ids = std::mem::take(&mut self.content_token_ids); self.push_child(ctx, AstNode::content(text));
self.push_child_with_tokens(ctx, NodeBody::Content(text), token_ids);
} }
} }
} }
@ -654,7 +630,6 @@ impl ResponseParser {
pub fn finish(mut self, ctx: &mut ContextState) { pub fn finish(mut self, ctx: &mut ContextState) {
if !self.buf.is_empty() { if !self.buf.is_empty() {
self.content_parts.push(std::mem::take(&mut self.buf)); self.content_parts.push(std::mem::take(&mut self.buf));
self.content_token_ids.extend(self.buf_token_ids.drain(..));
} }
self.flush_content(ctx); self.flush_content(ctx);
} }