context: cache role header token lengths

Branch::tokens() was calling tokenizer::encode() on every call for
the role header ("system\n", "user\n", "assistant\n") and trailing
newline. In trim_conversation(), this meant hundreds of encode calls
per trim cycle.

These are fixed strings, so encode them once and cache their token lengths in OnceLock statics on first use.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-12 20:47:36 -04:00
parent ac6f1e9294
commit 72f4f1b617

View file

@ -31,8 +31,31 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use std::sync::OnceLock;
use super::tokenizer; use super::tokenizer;
// Cached token lengths for the fixed role-header strings — encoded once,
// lazily, on first use (see role_header_tokens / newline_tokens below).
// ROLE_TOKENS holds the lengths for "system\n", "user\n", "assistant\n"
// in that order; NEWLINE_TOKENS holds the length of a bare "\n".
static ROLE_TOKENS: OnceLock<[usize; 3]> = OnceLock::new();
static NEWLINE_TOKENS: OnceLock<usize> = OnceLock::new();
/// Token count of the header line for `role` (e.g. "user\n"), using the
/// lengths cached in `ROLE_TOKENS` so `tokenizer::encode` runs at most once
/// per role string for the lifetime of the process.
fn role_header_tokens(role: Role) -> usize {
    // Encode all three fixed header strings on the first call; subsequent
    // calls are a plain array load.
    let cached = ROLE_TOKENS.get_or_init(|| {
        ["system\n", "user\n", "assistant\n"].map(|s| tokenizer::encode(s).len())
    });
    // Map the role onto its slot in the cached array (same order as above).
    let slot = match role {
        Role::System => 0,
        Role::User => 1,
        Role::Assistant => 2,
    };
    cached[slot]
}
/// Token count of a bare "\n" separator, encoded once and memoized in
/// `NEWLINE_TOKENS`.
fn newline_tokens() -> usize {
    let cached = NEWLINE_TOKENS.get_or_init(|| tokenizer::encode("\n").len());
    *cached
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Types // Types
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@ -423,9 +446,9 @@ impl Ast for AstNode {
match self { match self {
Self::Leaf(leaf) => leaf.tokens(), Self::Leaf(leaf) => leaf.tokens(),
Self::Branch { role, children, .. } => { Self::Branch { role, children, .. } => {
1 + tokenizer::encode(&format!("{}\n", role.as_str())).len() 1 + role_header_tokens(*role)
+ children.iter().map(|c| c.tokens()).sum::<usize>() + children.iter().map(|c| c.tokens()).sum::<usize>()
+ 1 + tokenizer::encode("\n").len() + 1 + newline_tokens()
} }
} }
} }