Replace token counting with token generation via HuggingFace tokenizer
Add agent/tokenizer.rs with global Qwen 3.5 tokenizer that generates actual token IDs including chat template wrapping. ContextEntry now stores token_ids: Vec<u32> instead of tokens: usize — the count is derived from the length. ContextEntry::new() tokenizes automatically via the global tokenizer. ContextSection::push_entry() takes a raw ConversationEntry and tokenizes it. set_message() re-tokenizes without needing an external tokenizer parameter. Token IDs include the full chat template: <|im_start|>role\ncontent<|im_end|>\n — so concatenating token_ids across entries produces a ready-to-send prompt for vLLM's /v1/completions endpoint. The old tiktoken CoreBPE is now unused on Agent (will be removed in a follow-up). Token counts are now exact for Qwen 3.5 instead of the ~85-90% approximation from cl100k_base. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
70ee7abea5
commit
5e4067c04f
10 changed files with 540 additions and 97 deletions
|
|
@ -33,12 +33,25 @@ pub enum ConversationEntry {
|
|||
#[derive(Debug, Clone)]
|
||||
pub struct ContextEntry {
|
||||
pub entry: ConversationEntry,
|
||||
/// Cached token count (0 for Log entries).
|
||||
pub tokens: usize,
|
||||
/// Cached tokenization — the actual token IDs for this entry's
|
||||
/// contribution to the prompt (including chat template wrapping).
|
||||
/// Empty for Log entries.
|
||||
pub token_ids: Vec<u32>,
|
||||
/// When this entry was added to the context.
|
||||
pub timestamp: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl ContextEntry {
|
||||
/// Create a new entry, tokenizing via the global tokenizer.
|
||||
pub fn new(entry: ConversationEntry, timestamp: Option<DateTime<Utc>>) -> Self {
|
||||
let token_ids = super::tokenizer::tokenize_conv_entry(&entry);
|
||||
Self { entry, token_ids, timestamp }
|
||||
}
|
||||
|
||||
/// Token count — derived from cached token_ids length.
|
||||
pub fn tokens(&self) -> usize { self.token_ids.len() }
|
||||
}
|
||||
|
||||
/// A named section of the context window with cached token total.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContextSection {
|
||||
|
|
@ -58,32 +71,40 @@ impl ContextSection {
|
|||
pub fn len(&self) -> usize { self.entries.len() }
|
||||
pub fn is_empty(&self) -> bool { self.entries.is_empty() }
|
||||
|
||||
/// Push an entry, updating the cached token total.
|
||||
/// Push a ConversationEntry, tokenizing it and updating the total.
|
||||
pub fn push_entry(&mut self, entry: ConversationEntry, timestamp: Option<DateTime<Utc>>) {
|
||||
let ce = ContextEntry::new(entry, timestamp);
|
||||
self.tokens += ce.tokens();
|
||||
self.entries.push(ce);
|
||||
}
|
||||
|
||||
/// Push a pre-built ContextEntry (for restore, cloning, etc).
|
||||
pub fn push(&mut self, entry: ContextEntry) {
|
||||
self.tokens += entry.tokens;
|
||||
self.tokens += entry.tokens();
|
||||
self.entries.push(entry);
|
||||
}
|
||||
|
||||
/// Replace an entry at `index`, adjusting the token total.
|
||||
pub fn set(&mut self, index: usize, entry: ContextEntry) {
|
||||
self.tokens -= self.entries[index].tokens;
|
||||
self.tokens += entry.tokens;
|
||||
self.tokens -= self.entries[index].tokens();
|
||||
self.tokens += entry.tokens();
|
||||
self.entries[index] = entry;
|
||||
}
|
||||
|
||||
/// Remove an entry at `index`, adjusting the token total.
|
||||
pub fn del(&mut self, index: usize) -> ContextEntry {
|
||||
let removed = self.entries.remove(index);
|
||||
self.tokens -= removed.tokens;
|
||||
self.tokens -= removed.tokens();
|
||||
removed
|
||||
}
|
||||
|
||||
/// Replace the message inside an entry, recomputing its token count.
|
||||
pub fn set_message(&mut self, index: usize, tokenizer: &CoreBPE, msg: Message) {
|
||||
let old_tokens = self.entries[index].tokens;
|
||||
/// Replace the message inside an entry, re-tokenizing it.
|
||||
pub fn set_message(&mut self, index: usize, msg: Message) {
|
||||
let old_tokens = self.entries[index].tokens();
|
||||
*self.entries[index].entry.message_mut() = msg;
|
||||
let new_tokens = msg_token_count(tokenizer, self.entries[index].entry.api_message());
|
||||
self.entries[index].tokens = new_tokens;
|
||||
self.entries[index].token_ids = super::tokenizer::tokenize_conv_entry(
|
||||
&self.entries[index].entry);
|
||||
let new_tokens = self.entries[index].tokens();
|
||||
self.tokens = self.tokens - old_tokens + new_tokens;
|
||||
}
|
||||
|
||||
|
|
@ -96,7 +117,7 @@ impl ContextSection {
|
|||
|
||||
/// Bulk replace all entries, recomputing token total.
|
||||
pub fn set_entries(&mut self, entries: Vec<ContextEntry>) {
|
||||
self.tokens = entries.iter().map(|e| e.tokens).sum();
|
||||
self.tokens = entries.iter().map(|e| e.tokens()).sum();
|
||||
self.entries = entries;
|
||||
}
|
||||
|
||||
|
|
@ -104,7 +125,7 @@ impl ContextSection {
|
|||
pub fn trim(&mut self, fixed_tokens: usize) {
|
||||
let result = trim_entries(&self.entries, fixed_tokens);
|
||||
self.entries = result;
|
||||
self.tokens = self.entries.iter().map(|e| e.tokens).sum();
|
||||
self.tokens = self.entries.iter().map(|e| e.tokens()).sum();
|
||||
}
|
||||
|
||||
/// Clear all entries.
|
||||
|
|
@ -189,9 +210,9 @@ fn trim_entries(entries: &[ContextEntry], fixed_tokens: usize) -> Vec<ContextEnt
|
|||
.map(|(_, e)| e.clone())
|
||||
.collect();
|
||||
|
||||
let entry_total = |r: &[ContextEntry]| -> usize { r.iter().map(|e| e.tokens).sum::<usize>() };
|
||||
let entry_total = |r: &[ContextEntry]| -> usize { r.iter().map(|e| e.tokens()).sum::<usize>() };
|
||||
let mem_total = |r: &[ContextEntry]| -> usize {
|
||||
r.iter().filter(|e| e.entry.is_memory()).map(|e| e.tokens).sum()
|
||||
r.iter().filter(|e| e.entry.is_memory()).map(|e| e.tokens()).sum()
|
||||
};
|
||||
|
||||
dbglog!("[trim] max={} fixed={} total={} entries={}",
|
||||
|
|
|
|||
108
src/agent/mod.rs
108
src/agent/mod.rs
|
|
@ -16,6 +16,7 @@
|
|||
pub mod api;
|
||||
pub mod context;
|
||||
pub mod oneshot;
|
||||
pub mod tokenizer;
|
||||
pub mod tools;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
|
@ -196,19 +197,12 @@ impl Agent {
|
|||
.expect("failed to load cl100k_base tokenizer");
|
||||
|
||||
let mut system = ContextSection::new("System prompt");
|
||||
system.push(ContextEntry {
|
||||
entry: ConversationEntry::System(Message::system(&system_prompt)),
|
||||
tokens: context::msg_token_count(&tokenizer, &Message::system(&system_prompt)),
|
||||
timestamp: None,
|
||||
});
|
||||
system.push(ContextEntry::new(
|
||||
ConversationEntry::System(Message::system(&system_prompt)), None));
|
||||
let mut identity = ContextSection::new("Identity");
|
||||
for (_name, content) in &personality {
|
||||
let msg = Message::user(content);
|
||||
identity.push(ContextEntry {
|
||||
tokens: context::msg_token_count(&tokenizer, &msg),
|
||||
entry: ConversationEntry::Message(msg),
|
||||
timestamp: None,
|
||||
});
|
||||
identity.push(ContextEntry::new(
|
||||
ConversationEntry::Message(Message::user(content)), None));
|
||||
}
|
||||
let context = ContextState {
|
||||
system,
|
||||
|
|
@ -324,12 +318,8 @@ impl Agent {
|
|||
eprintln!("warning: failed to log entry: {:#}", e);
|
||||
}
|
||||
}
|
||||
let tokens = if entry.is_log() || entry.is_thinking() { 0 } else {
|
||||
context::msg_token_count(&self.tokenizer, entry.api_message())
|
||||
};
|
||||
self.context.conversation.push(ContextEntry {
|
||||
entry, tokens, timestamp: Some(chrono::Utc::now()),
|
||||
});
|
||||
self.context.conversation.push(ContextEntry::new(
|
||||
entry, Some(chrono::Utc::now())));
|
||||
|
||||
self.changed.notify_one();
|
||||
}
|
||||
|
|
@ -348,22 +338,19 @@ impl Agent {
|
|||
if let Some(idx) = self.streaming_index() {
|
||||
let mut msg = self.context.conversation.entries()[idx].entry.message().clone();
|
||||
msg.append_content(text);
|
||||
self.context.conversation.set_message(idx, &self.tokenizer, msg);
|
||||
self.context.conversation.set_message(idx, msg);
|
||||
} else {
|
||||
let msg = Message {
|
||||
role: Role::Assistant,
|
||||
content: Some(MessageContent::Text(text.to_string())),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
timestamp: None,
|
||||
};
|
||||
let tokens = context::msg_token_count(&self.tokenizer, &msg);
|
||||
self.context.conversation.push(ContextEntry {
|
||||
entry: ConversationEntry::Message(msg),
|
||||
tokens,
|
||||
timestamp: None,
|
||||
});
|
||||
self.context.conversation.push(ContextEntry::new(
|
||||
ConversationEntry::Message(Message {
|
||||
role: Role::Assistant,
|
||||
content: Some(MessageContent::Text(text.to_string())),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
timestamp: None,
|
||||
}),
|
||||
None,
|
||||
));
|
||||
}
|
||||
|
||||
self.changed.notify_one();
|
||||
|
|
@ -375,12 +362,10 @@ impl Agent {
|
|||
if let Some(i) = self.streaming_index() {
|
||||
let mut stamped = msg.clone();
|
||||
stamped.stamp();
|
||||
let tokens = context::msg_token_count(&self.tokenizer, &stamped);
|
||||
self.context.conversation.set(i, ContextEntry {
|
||||
entry: ConversationEntry::Message(stamped),
|
||||
tokens,
|
||||
timestamp: Some(chrono::Utc::now()),
|
||||
});
|
||||
self.context.conversation.set(i, ContextEntry::new(
|
||||
ConversationEntry::Message(stamped),
|
||||
Some(chrono::Utc::now()),
|
||||
));
|
||||
} else {
|
||||
self.push_message(msg.clone());
|
||||
}
|
||||
|
|
@ -770,16 +755,15 @@ impl Agent {
|
|||
|
||||
for node in journal_nodes[..cutoff_idx].iter().rev() {
|
||||
let msg = Message::user(&node.content);
|
||||
let tokens = context::msg_token_count(&self.tokenizer, &msg);
|
||||
if total_tokens + tokens > journal_budget && !journal_entries.is_empty() {
|
||||
let ce = ContextEntry::new(
|
||||
ConversationEntry::Message(msg),
|
||||
chrono::DateTime::from_timestamp(node.created_at, 0),
|
||||
);
|
||||
if total_tokens + ce.tokens() > journal_budget && !journal_entries.is_empty() {
|
||||
break;
|
||||
}
|
||||
journal_entries.push(ContextEntry {
|
||||
entry: ConversationEntry::Message(msg),
|
||||
tokens,
|
||||
timestamp: chrono::DateTime::from_timestamp(node.created_at, 0),
|
||||
});
|
||||
total_tokens += tokens;
|
||||
total_tokens += ce.tokens();
|
||||
journal_entries.push(ce);
|
||||
}
|
||||
journal_entries.reverse();
|
||||
dbg_log!("[journal] loaded {} entries, {} tokens", journal_entries.len(), total_tokens);
|
||||
|
|
@ -842,12 +826,10 @@ impl Agent {
|
|||
}
|
||||
let mut new_msg = msg.clone();
|
||||
new_msg.content = Some(MessageContent::Text(replacement));
|
||||
let tokens = context::msg_token_count(&self.tokenizer, &new_msg);
|
||||
self.context.conversation.set(i, ContextEntry {
|
||||
entry: ConversationEntry::Message(new_msg),
|
||||
tokens,
|
||||
timestamp: old.timestamp,
|
||||
});
|
||||
self.context.conversation.set(i, ContextEntry::new(
|
||||
ConversationEntry::Message(new_msg),
|
||||
old.timestamp,
|
||||
));
|
||||
}
|
||||
}
|
||||
self.generation += 1;
|
||||
|
|
@ -866,19 +848,12 @@ impl Agent {
|
|||
match crate::config::reload_for_model(&self.app_config, &self.prompt_file) {
|
||||
Ok((system_prompt, personality)) => {
|
||||
self.context.system.clear();
|
||||
self.context.system.push(ContextEntry {
|
||||
entry: ConversationEntry::System(Message::system(&system_prompt)),
|
||||
tokens: context::msg_token_count(&self.tokenizer, &Message::system(&system_prompt)),
|
||||
timestamp: None,
|
||||
});
|
||||
self.context.system.push(ContextEntry::new(
|
||||
ConversationEntry::System(Message::system(&system_prompt)), None));
|
||||
self.context.identity.clear();
|
||||
for (_name, content) in &personality {
|
||||
let msg = Message::user(content);
|
||||
self.context.identity.push(ContextEntry {
|
||||
tokens: context::msg_token_count(&self.tokenizer, &msg),
|
||||
entry: ConversationEntry::Message(msg),
|
||||
timestamp: None,
|
||||
});
|
||||
self.context.identity.push(ContextEntry::new(
|
||||
ConversationEntry::Message(Message::user(content)), None));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
|
|
@ -932,16 +907,13 @@ impl Agent {
|
|||
let all: Vec<ContextEntry> = entries.into_iter()
|
||||
.filter(|e| !e.is_log() && !e.is_thinking() && e.message().role != Role::System)
|
||||
.map(|e| {
|
||||
let tokens = if e.is_log() { 0 } else {
|
||||
context::msg_token_count(&self.tokenizer, e.api_message())
|
||||
};
|
||||
let timestamp = if e.is_log() { None } else {
|
||||
let timestamp = if e.is_log() || e.is_thinking() { None } else {
|
||||
e.message().timestamp.as_ref().and_then(|ts| {
|
||||
chrono::DateTime::parse_from_rfc3339(ts).ok()
|
||||
.map(|dt| dt.with_timezone(&chrono::Utc))
|
||||
})
|
||||
};
|
||||
ContextEntry { entry: e, tokens, timestamp }
|
||||
ContextEntry::new(e, timestamp)
|
||||
})
|
||||
.collect();
|
||||
let mem_count = all.iter().filter(|e| e.entry.is_memory()).count();
|
||||
|
|
|
|||
82
src/agent/tokenizer.rs
Normal file
82
src/agent/tokenizer.rs
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
// tokenizer.rs — Qwen tokenizer for direct token generation
|
||||
//
|
||||
// Loads the HuggingFace tokenizer.json for the target model and provides
|
||||
// tokenization for context entries. The tokenizer is loaded once globally
|
||||
// and shared across all callers.
|
||||
//
|
||||
// Token IDs include the chat template wrapping:
|
||||
// <|im_start|>role\ncontent<|im_end|>\n
|
||||
// so concatenating token_ids across entries produces a ready-to-send prompt.
|
||||
|
||||
use std::sync::OnceLock;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
static TOKENIZER: OnceLock<Tokenizer> = OnceLock::new();
|
||||
|
||||
/// Special token IDs for Qwen 3.5
|
||||
pub const IM_START: u32 = 248045;
|
||||
pub const IM_END: u32 = 248046;
|
||||
|
||||
/// Initialize the global tokenizer from a file path.
|
||||
/// Call once at startup. Panics if the file can't be loaded.
|
||||
pub fn init(path: &str) {
|
||||
let t = Tokenizer::from_file(path)
|
||||
.unwrap_or_else(|e| panic!("failed to load tokenizer from {}: {}", path, e));
|
||||
TOKENIZER.set(t).ok();
|
||||
}
|
||||
|
||||
/// Get the global tokenizer. Panics if not initialized.
|
||||
fn get() -> &'static Tokenizer {
|
||||
TOKENIZER.get().expect("tokenizer not initialized — call tokenizer::init() first")
|
||||
}
|
||||
|
||||
/// Tokenize a raw string, returning token IDs.
|
||||
pub fn encode(text: &str) -> Vec<u32> {
|
||||
get().encode(text, false)
|
||||
.unwrap_or_else(|e| panic!("tokenization failed: {}", e))
|
||||
.get_ids()
|
||||
.to_vec()
|
||||
}
|
||||
|
||||
/// Tokenize a chat entry with template wrapping:
|
||||
/// <|im_start|>role\ncontent<|im_end|>\n
|
||||
/// Returns the complete token ID sequence for this entry.
|
||||
pub fn tokenize_entry(role: &str, content: &str) -> Vec<u32> {
|
||||
let mut ids = Vec::new();
|
||||
ids.push(IM_START);
|
||||
ids.extend(encode(role));
|
||||
ids.extend(encode("\n"));
|
||||
ids.extend(encode(content));
|
||||
ids.push(IM_END);
|
||||
ids.extend(encode("\n"));
|
||||
ids
|
||||
}
|
||||
|
||||
/// Count tokens for a string (convenience for budget checks).
|
||||
pub fn count(text: &str) -> usize {
|
||||
encode(text).len()
|
||||
}
|
||||
|
||||
/// Decode token IDs back to text.
|
||||
pub fn decode(ids: &[u32]) -> String {
|
||||
get().decode(ids, true)
|
||||
.unwrap_or_else(|e| panic!("detokenization failed: {}", e))
|
||||
}
|
||||
|
||||
/// Check if the tokenizer is initialized.
|
||||
pub fn is_initialized() -> bool {
|
||||
TOKENIZER.get().is_some()
|
||||
}
|
||||
|
||||
/// Tokenize a ConversationEntry with its role and content.
|
||||
pub fn tokenize_conv_entry(entry: &super::context::ConversationEntry) -> Vec<u32> {
|
||||
use super::context::ConversationEntry;
|
||||
match entry {
|
||||
ConversationEntry::System(m) => tokenize_entry("system", m.content_text()),
|
||||
ConversationEntry::Message(m) => tokenize_entry(m.role_str(), m.content_text()),
|
||||
ConversationEntry::Memory { message, .. } => tokenize_entry("memory", message.content_text()),
|
||||
ConversationEntry::Dmn(m) => tokenize_entry("dmn", m.content_text()),
|
||||
ConversationEntry::Thinking(text) => tokenize_entry("thinking", text),
|
||||
ConversationEntry::Log(_) => vec![], // logs don't consume tokens
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue