Fix parser: re-encode tokens instead of tracking model IDs through tag splits
The parser can't reliably split model-produced token IDs at tag boundaries (<think>, <tool_call>) because BPE tokens can span across tags. Instead, each leaf gets re-encoded from its text content via the local tokenizer. This gives clean token boundaries aligned with semantic structure — better for budgeting and potentially for the model during fine-tuning. Also skip serializing token_ids to conversation log (they're cached state, recomputed on construction). Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
88ac5e10ce
commit
5ec2ff95d8
1 changed file with 4 additions and 29 deletions
|
|
@ -66,6 +66,7 @@ pub enum NodeBody {
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct NodeLeaf {
|
pub struct NodeLeaf {
|
||||||
body: NodeBody,
|
body: NodeBody,
|
||||||
|
#[serde(skip)]
|
||||||
token_ids: Vec<u32>,
|
token_ids: Vec<u32>,
|
||||||
timestamp: Option<DateTime<Utc>>,
|
timestamp: Option<DateTime<Utc>>,
|
||||||
}
|
}
|
||||||
|
|
@ -115,15 +116,11 @@ pub struct ResponseParser {
|
||||||
branch_idx: usize,
|
branch_idx: usize,
|
||||||
call_counter: u32,
|
call_counter: u32,
|
||||||
buf: String,
|
buf: String,
|
||||||
buf_token_ids: Vec<u32>,
|
|
||||||
content_parts: Vec<String>,
|
content_parts: Vec<String>,
|
||||||
content_token_ids: Vec<u32>,
|
|
||||||
in_think: bool,
|
in_think: bool,
|
||||||
think_buf: String,
|
think_buf: String,
|
||||||
think_token_ids: Vec<u32>,
|
|
||||||
in_tool_call: bool,
|
in_tool_call: bool,
|
||||||
tool_call_buf: String,
|
tool_call_buf: String,
|
||||||
tool_call_token_ids: Vec<u32>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Role {
|
impl Role {
|
||||||
|
|
@ -471,15 +468,11 @@ impl ResponseParser {
|
||||||
branch_idx,
|
branch_idx,
|
||||||
call_counter: 0,
|
call_counter: 0,
|
||||||
buf: String::new(),
|
buf: String::new(),
|
||||||
buf_token_ids: Vec::new(),
|
|
||||||
content_parts: Vec::new(),
|
content_parts: Vec::new(),
|
||||||
content_token_ids: Vec::new(),
|
|
||||||
in_think: false,
|
in_think: false,
|
||||||
think_buf: String::new(),
|
think_buf: String::new(),
|
||||||
think_token_ids: Vec::new(),
|
|
||||||
in_tool_call: false,
|
in_tool_call: false,
|
||||||
tool_call_buf: String::new(),
|
tool_call_buf: String::new(),
|
||||||
tool_call_token_ids: Vec::new(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -523,23 +516,19 @@ impl ResponseParser {
|
||||||
(rx, handle)
|
(rx, handle)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn feed_token(&mut self, text: &str, token_id: u32, ctx: &mut ContextState) -> Vec<PendingToolCall> {
|
pub fn feed_token(&mut self, text: &str, _token_id: u32, ctx: &mut ContextState) -> Vec<PendingToolCall> {
|
||||||
let mut pending = Vec::new();
|
let mut pending = Vec::new();
|
||||||
self.buf.push_str(text);
|
self.buf.push_str(text);
|
||||||
self.buf_token_ids.push(token_id);
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
if self.in_think {
|
if self.in_think {
|
||||||
match self.buf.find("</think>") {
|
match self.buf.find("</think>") {
|
||||||
Some(end) => {
|
Some(end) => {
|
||||||
self.think_buf.push_str(&self.buf[..end]);
|
self.think_buf.push_str(&self.buf[..end]);
|
||||||
// Token IDs: move all buffered IDs to think (approximate split)
|
|
||||||
self.think_token_ids.extend(self.buf_token_ids.drain(..));
|
|
||||||
self.buf = self.buf[end + 8..].to_string();
|
self.buf = self.buf[end + 8..].to_string();
|
||||||
self.in_think = false;
|
self.in_think = false;
|
||||||
let text = std::mem::take(&mut self.think_buf);
|
let text = std::mem::take(&mut self.think_buf);
|
||||||
let ids = std::mem::take(&mut self.think_token_ids);
|
self.push_child(ctx, AstNode::thinking(text));
|
||||||
self.push_child_with_tokens(ctx, NodeBody::Thinking(text), ids);
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
|
|
@ -548,7 +537,6 @@ impl ResponseParser {
|
||||||
let safe = self.buf.floor_char_boundary(safe);
|
let safe = self.buf.floor_char_boundary(safe);
|
||||||
self.think_buf.push_str(&self.buf[..safe]);
|
self.think_buf.push_str(&self.buf[..safe]);
|
||||||
self.buf = self.buf[safe..].to_string();
|
self.buf = self.buf[safe..].to_string();
|
||||||
// Keep token IDs in buf (lookahead)
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -559,12 +547,10 @@ impl ResponseParser {
|
||||||
match self.buf.find("</tool_call>") {
|
match self.buf.find("</tool_call>") {
|
||||||
Some(end) => {
|
Some(end) => {
|
||||||
self.tool_call_buf.push_str(&self.buf[..end]);
|
self.tool_call_buf.push_str(&self.buf[..end]);
|
||||||
self.tool_call_token_ids.extend(self.buf_token_ids.drain(..));
|
|
||||||
self.buf = self.buf[end + 12..].to_string();
|
self.buf = self.buf[end + 12..].to_string();
|
||||||
self.in_tool_call = false;
|
self.in_tool_call = false;
|
||||||
if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) {
|
if let Some((name, args)) = parse_tool_call_body(&self.tool_call_buf) {
|
||||||
self.flush_content(ctx);
|
self.flush_content(ctx);
|
||||||
// Tool calls get re-tokenized from structured data
|
|
||||||
self.push_child(ctx, AstNode::tool_call(&name, &args));
|
self.push_child(ctx, AstNode::tool_call(&name, &args));
|
||||||
self.call_counter += 1;
|
self.call_counter += 1;
|
||||||
pending.push(PendingToolCall {
|
pending.push(PendingToolCall {
|
||||||
|
|
@ -574,7 +560,6 @@ impl ResponseParser {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
self.tool_call_buf.clear();
|
self.tool_call_buf.clear();
|
||||||
self.tool_call_token_ids.clear();
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
|
|
@ -603,8 +588,6 @@ impl ResponseParser {
|
||||||
if pos > 0 {
|
if pos > 0 {
|
||||||
self.content_parts.push(self.buf[..pos].to_string());
|
self.content_parts.push(self.buf[..pos].to_string());
|
||||||
}
|
}
|
||||||
// Move token IDs to content accumulator
|
|
||||||
self.content_token_ids.extend(self.buf_token_ids.drain(..));
|
|
||||||
if self.buf[pos..].starts_with("<think>") {
|
if self.buf[pos..].starts_with("<think>") {
|
||||||
self.buf = self.buf[pos + 7..].to_string();
|
self.buf = self.buf[pos + 7..].to_string();
|
||||||
self.flush_content(ctx);
|
self.flush_content(ctx);
|
||||||
|
|
@ -622,7 +605,6 @@ impl ResponseParser {
|
||||||
let safe = self.buf.floor_char_boundary(safe);
|
let safe = self.buf.floor_char_boundary(safe);
|
||||||
self.content_parts.push(self.buf[..safe].to_string());
|
self.content_parts.push(self.buf[..safe].to_string());
|
||||||
self.buf = self.buf[safe..].to_string();
|
self.buf = self.buf[safe..].to_string();
|
||||||
// Keep token IDs in buf (lookahead)
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -636,17 +618,11 @@ impl ResponseParser {
|
||||||
ctx.push_child(Section::Conversation, self.branch_idx, child);
|
ctx.push_child(Section::Conversation, self.branch_idx, child);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn push_child_with_tokens(&self, ctx: &mut ContextState, body: NodeBody, token_ids: Vec<u32>) {
|
|
||||||
let leaf = NodeLeaf { body, token_ids, timestamp: None };
|
|
||||||
ctx.push_child(Section::Conversation, self.branch_idx, AstNode::Leaf(leaf));
|
|
||||||
}
|
|
||||||
|
|
||||||
fn flush_content(&mut self, ctx: &mut ContextState) {
|
fn flush_content(&mut self, ctx: &mut ContextState) {
|
||||||
if !self.content_parts.is_empty() {
|
if !self.content_parts.is_empty() {
|
||||||
let text: String = self.content_parts.drain(..).collect();
|
let text: String = self.content_parts.drain(..).collect();
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
let token_ids = std::mem::take(&mut self.content_token_ids);
|
self.push_child(ctx, AstNode::content(text));
|
||||||
self.push_child_with_tokens(ctx, NodeBody::Content(text), token_ids);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -654,7 +630,6 @@ impl ResponseParser {
|
||||||
pub fn finish(mut self, ctx: &mut ContextState) {
|
pub fn finish(mut self, ctx: &mut ContextState) {
|
||||||
if !self.buf.is_empty() {
|
if !self.buf.is_empty() {
|
||||||
self.content_parts.push(std::mem::take(&mut self.buf));
|
self.content_parts.push(std::mem::take(&mut self.buf));
|
||||||
self.content_token_ids.extend(self.buf_token_ids.drain(..));
|
|
||||||
}
|
}
|
||||||
self.flush_content(ctx);
|
self.flush_content(ctx);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue