diff --git a/src/agent/context.rs b/src/agent/context.rs
index 2e54391..c43c023 100644
--- a/src/agent/context.rs
+++ b/src/agent/context.rs
@@ -31,8 +31,38 @@
 
 use chrono::{DateTime, Utc};
 use serde::{Serialize, Deserialize};
+use std::sync::OnceLock;
 use super::tokenizer;
 
+// Cached token lengths for role headers — computed once on first use.
+// "system\n", "user\n", "assistant\n" and "\n" are fixed strings.
+static ROLE_TOKENS: OnceLock<[usize; 3]> = OnceLock::new();
+static NEWLINE_TOKENS: OnceLock<usize> = OnceLock::new();
+
+/// Returns the token count of the `"<role>\n"` header for `role`.
+///
+/// The counts for all three roles are produced by `tokenizer::encode` on the
+/// first call and cached in `ROLE_TOKENS`; later calls only index that array.
+fn role_header_tokens(role: Role) -> usize {
+    // NOTE(review): assumes `Role::as_str()` yields exactly "system", "user" and
+    // "assistant", matching the `format!("{}\n", role.as_str())` this replaces — confirm.
+    let tokens = ROLE_TOKENS.get_or_init(|| [
+        tokenizer::encode("system\n").len(),
+        tokenizer::encode("user\n").len(),
+        tokenizer::encode("assistant\n").len(),
+    ]);
+    match role {
+        Role::System => tokens[0],
+        Role::User => tokens[1],
+        Role::Assistant => tokens[2],
+    }
+}
+
+/// Returns the token count of a lone `"\n"`, computed once and cached in `NEWLINE_TOKENS`.
+fn newline_tokens() -> usize {
+    *NEWLINE_TOKENS.get_or_init(|| tokenizer::encode("\n").len())
+}
+
 // ---------------------------------------------------------------------------
 // Types
 // ---------------------------------------------------------------------------
@@ -423,9 +453,9 @@ impl Ast for AstNode {
         match self {
             Self::Leaf(leaf) => leaf.tokens(),
             Self::Branch { role, children, .. } => {
-                1 + tokenizer::encode(&format!("{}\n", role.as_str())).len()
+                1 + role_header_tokens(*role)
                 + children.iter().map(|c| c.tokens()).sum::<usize>()
-                + 1 + tokenizer::encode("\n").len()
+                + 1 + newline_tokens()
             }
         }
     }