Fix parser: re-encode tokens instead of tracking model IDs through tag splits
The parser cannot reliably split model-produced token IDs at tag boundaries (<think>, <tool_call>) because BPE tokens can span across tags. Instead, each leaf is re-encoded from its text content via the local tokenizer. This yields clean token boundaries aligned with the semantic structure — better for budgeting, and potentially better for the model during fine-tuning. Also skip serializing token_ids to the conversation log: they are cached state, recomputed on construction. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
88ac5e10ce
commit
5ec2ff95d8
1 changed file with 4 additions and 29 deletions
|
|
@ -66,6 +66,7 @@ pub enum NodeBody {
|
|||
/// A single leaf node in the conversation AST.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeLeaf {
    // Semantic payload of this leaf (Content, Thinking, tool call, ...).
    body: NodeBody,
    // Cached token IDs for this leaf's text. Skipped during serialization:
    // per the commit message these are cached state, re-encoded from the
    // leaf's text on construction rather than persisted in the log.
    #[serde(skip)]
    token_ids: Vec<u32>,
    // Creation time of the leaf; None when no timestamp was recorded
    // (e.g. parser-produced leaves — TODO confirm against callers).
    timestamp: Option<DateTime<Utc>>,
}
|
||||
|
|
@ -115,15 +116,11 @@ pub struct ResponseParser {
|
|||
branch_idx: usize,
|
||||
call_counter: u32,
|
||||
buf: String,
|
||||
buf_token_ids: Vec<u32>,
|
||||
content_parts: Vec<String>,
|
||||
content_token_ids: Vec<u32>,
|
||||
in_think: bool,
|
||||
think_buf: String,
|
||||
think_token_ids: Vec<u32>,
|
||||
in_tool_call: bool,
|
||||
tool_call_buf: String,
|
||||
tool_call_token_ids: Vec<u32>,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
|
|
@ -471,15 +468,11 @@ impl ResponseParser {
|
|||
branch_idx,
|
||||
call_counter: 0,
|
||||
buf: String::new(),
|
||||
buf_token_ids: Vec::new(),
|
||||
content_parts: Vec::new(),
|
||||
content_token_ids: Vec::new(),
|
||||
in_think: false,
|
||||
think_buf: String::new(),
|
||||
think_token_ids: Vec::new(),
|
||||
in_tool_call: false,
|
||||
tool_call_buf: String::new(),
|
||||
tool_call_token_ids: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -523,23 +516,19 @@ impl ResponseParser {
|
|||
(rx, handle)
|
||||
}
|
||||
|
||||
pub fn feed_token(&mut self, text: &str, token_id: u32, ctx: &mut ContextState) -> Vec<PendingToolCall> {
|
||||
pub fn feed_token(&mut self, text: &str, _token_id: u32, ctx: &mut ContextState) -> Vec<PendingToolCall> {
|
||||
let mut pending = Vec::new();
|
||||
self.buf.push_str(text);
|
||||
self.buf_token_ids.push(token_id);
|
||||
|
||||
loop {
|
||||
if self.in_think {
|
||||
match self.buf.find("</think>") {
|
||||
Some(end) => {
|
||||
self.think_buf.push_str(&self.buf[..end]);
|
||||
// Token IDs: move all buffered IDs to think (approximate split)
|
||||
self.think_token_ids.extend(self.buf_token_ids.drain(..));
|
||||
self.buf = self.buf[end + 8..].to_string();
|
||||
self.in_think = false;
|
||||
let text = std::mem::take(&mut self.think_buf);
|
||||
let ids = std::mem::take(&mut self.think_token_ids);
|
||||
self.push_child_with_tokens(ctx, NodeBody::Thinking(text), ids);
|
||||
self.push_child(ctx, AstNode::thinking(text));
|
||||
continue;
|
||||
}
|
||||
None => {
|
||||
|
|
@ -548,7 +537,6 @@ impl ResponseParser {
|
|||
let safe = self.buf.floor_char_boundary(safe);
|
||||
self.think_buf.push_str(&self.buf[..safe]);
|
||||
self.buf = self.buf[safe..].to_string();
|
||||
// Keep token IDs in buf (lookahead)
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
@ -559,12 +547,10 @@ impl ResponseParser {
|
|||
match self.buf.find("</tool_call>") {
|
||||
Some(end) => {
|
||||
self.tool_call_buf.push_str(&self.buf[..end]);
|
||||
self.tool_call_token_ids.extend(self.buf_token_ids.drain(..));
|
||||
self.buf = self.buf[end + 12..].to_string();
|
||||
self.in_tool_call = false;
|
||||
if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) {
|
||||
self.flush_content(ctx);
|
||||
// Tool calls get re-tokenized from structured data
|
||||
self.push_child(ctx, AstNode::tool_call(&name, &args));
|
||||
self.call_counter += 1;
|
||||
pending.push(PendingToolCall {
|
||||
|
|
@ -574,7 +560,6 @@ impl ResponseParser {
|
|||
});
|
||||
}
|
||||
self.tool_call_buf.clear();
|
||||
self.tool_call_token_ids.clear();
|
||||
continue;
|
||||
}
|
||||
None => {
|
||||
|
|
@ -603,8 +588,6 @@ impl ResponseParser {
|
|||
if pos > 0 {
|
||||
self.content_parts.push(self.buf[..pos].to_string());
|
||||
}
|
||||
// Move token IDs to content accumulator
|
||||
self.content_token_ids.extend(self.buf_token_ids.drain(..));
|
||||
if self.buf[pos..].starts_with("<think>") {
|
||||
self.buf = self.buf[pos + 7..].to_string();
|
||||
self.flush_content(ctx);
|
||||
|
|
@ -622,7 +605,6 @@ impl ResponseParser {
|
|||
let safe = self.buf.floor_char_boundary(safe);
|
||||
self.content_parts.push(self.buf[..safe].to_string());
|
||||
self.buf = self.buf[safe..].to_string();
|
||||
// Keep token IDs in buf (lookahead)
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
@ -636,17 +618,11 @@ impl ResponseParser {
|
|||
ctx.push_child(Section::Conversation, self.branch_idx, child);
|
||||
}
|
||||
|
||||
/// Wrap `body` plus its accumulated token IDs into a `NodeLeaf` and append
/// it as a child of this parser's branch in the conversation section.
/// NOTE(review): the surrounding diff removes this helper — leaves are now
/// built via `AstNode` constructors that re-encode tokens from text.
fn push_child_with_tokens(&self, ctx: &mut ContextState, body: NodeBody, token_ids: Vec<u32>) {
    // Parser-produced leaves carry no timestamp.
    let leaf = NodeLeaf { body, token_ids, timestamp: None };
    ctx.push_child(Section::Conversation, self.branch_idx, AstNode::Leaf(leaf));
}
|
||||
|
||||
/// Flush the accumulated content parts into a single Content leaf.
/// Produces no node when there are no parts, or when the parts
/// concatenate to an empty string.
/// NOTE(review): this span is unified-diff residue — the two push
/// statements below are the before/after of the same line, not two
/// sequential calls (as written, `text` would be moved twice).
fn flush_content(&mut self, ctx: &mut ContextState) {
    if !self.content_parts.is_empty() {
        // Drain all buffered parts into one owned string.
        let text: String = self.content_parts.drain(..).collect();
        if !text.is_empty() {
            // Removed by this commit: explicit token-ID bookkeeping.
            let token_ids = std::mem::take(&mut self.content_token_ids);
            self.push_child_with_tokens(ctx, NodeBody::Content(text), token_ids);
            // Added by this commit: the leaf re-encodes tokens from its text.
            self.push_child(ctx, AstNode::content(text));
        }
    }
}
|
||||
|
|
@ -654,7 +630,6 @@ impl ResponseParser {
|
|||
/// Consume the parser, flushing any text still sitting in the lookahead
/// buffer as trailing plain content.
/// NOTE(review): diff residue — the `content_token_ids` extend line is
/// removed by this commit (token IDs are re-encoded from text instead).
pub fn finish(mut self, ctx: &mut ContextState) {
    if !self.buf.is_empty() {
        // Whatever remains in the buffer at end of stream is content,
        // not an unfinished tag.
        self.content_parts.push(std::mem::take(&mut self.buf));
        self.content_token_ids.extend(self.buf_token_ids.drain(..));
    }
    self.flush_content(ctx);
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue