Add Ast trait for render/token_ids/tokens

Implemented by both AstNode and ContextState, so anything that
needs "give me the prompt" can take impl Ast.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-08 13:38:00 -04:00
parent f1397b7783
commit 942144949d

View file

@ -95,6 +95,12 @@ pub enum Section {
Conversation, Conversation,
} }
/// Common interface for anything that can produce prompt text and its tokens.
///
/// Implemented by both `AstNode` and `ContextState`, so code that only needs
/// the rendered prompt can accept `impl Ast` instead of a concrete type.
///
/// NOTE(review): implementations are free to compute `tokens()` independently
/// of `token_ids()`; whether `tokens() == token_ids().len()` always holds is
/// not guaranteed by this trait — confirm against each implementation.
pub trait Ast {
/// Returns the rendered text as an owned `String`.
fn render(&self) -> String;
/// Returns the token ids for this value.
fn token_ids(&self) -> Vec<u32>;
/// Returns the token count for this value.
fn tokens(&self) -> usize;
}
/// State machine for parsing a streaming assistant response into an AstNode. /// State machine for parsing a streaming assistant response into an AstNode.
/// Feed text chunks as they arrive; completed tool calls are returned for /// Feed text chunks as they arrive; completed tool calls are returned for
/// immediate dispatch. /// immediate dispatch.
@ -243,32 +249,6 @@ impl AstNode {
self self
} }
// -- Accessors ------------------------------------------------------------
/// Token count of this node.
///
/// A leaf reports its own `tokens()`; a branch reports the sum of its
/// children's counts (zero for a branch with no children).
pub fn tokens(&self) -> usize {
    match self {
        Self::Branch { children, .. } => {
            let mut total = 0;
            for child in children {
                total += child.tokens();
            }
            total
        }
        Self::Leaf(leaf) => leaf.tokens(),
    }
}
/// Token ids of this node.
///
/// A leaf returns a copy of its stored `token_ids`; a branch is tokenized
/// via `tokenize_branch` from its role and children.
pub fn token_ids(&self) -> Vec<u32> {
    match self {
        Self::Branch { role, children } => {
            let branch_role = *role;
            tokenize_branch(branch_role, children)
        }
        Self::Leaf(leaf) => {
            let ids = &leaf.token_ids;
            ids.clone()
        }
    }
}
/// Rendered text of this node.
///
/// A leaf renders its body; a branch is rendered via `render_branch` from
/// its role and children.
pub fn render(&self) -> String {
    match self {
        Self::Branch { role, children } => {
            let branch_role = *role;
            render_branch(branch_role, children)
        }
        Self::Leaf(leaf) => {
            let body = &leaf.body;
            body.render()
        }
    }
}
pub fn children(&self) -> &[AstNode] { pub fn children(&self) -> &[AstNode] {
match self { match self {
Self::Branch { children, .. } => children, Self::Branch { children, .. } => children,
@ -314,6 +294,32 @@ impl AstNode {
} }
} }
/// `Ast` for a single node: leaves answer from their own data, branches are
/// rendered through `render_branch` and tokenized from that rendering.
impl Ast for AstNode {
    /// Leaf: the body's rendering. Branch: `render_branch` over role and children.
    fn render(&self) -> String {
        match self {
            Self::Branch { role, children } => render_branch(*role, children),
            Self::Leaf(leaf) => leaf.body.render(),
        }
    }

    /// Leaf: a copy of the stored token ids. Branch: the rendered branch text
    /// passed through `tokenizer::encode`.
    fn token_ids(&self) -> Vec<u32> {
        match self {
            Self::Branch { role, children } => {
                let rendered = render_branch(*role, children);
                tokenizer::encode(&rendered)
            }
            Self::Leaf(leaf) => leaf.token_ids.clone(),
        }
    }

    /// Leaf: the leaf's own count. Branch: the sum of the children's counts.
    fn tokens(&self) -> usize {
        match self {
            Self::Branch { children, .. } => {
                let mut total = 0;
                for child in children {
                    total += child.tokens();
                }
                total
            }
            Self::Leaf(leaf) => leaf.tokens(),
        }
    }
}
fn truncate_preview(s: &str, max: usize) -> String { fn truncate_preview(s: &str, max: usize) -> String {
let preview: String = s.chars().take(max).collect(); let preview: String = s.chars().take(max).collect();
let preview = preview.replace('\n', " "); let preview = preview.replace('\n', " ");
@ -329,10 +335,6 @@ fn render_branch(role: Role, children: &[AstNode]) -> String {
s s
} }
/// Tokenizes a branch by rendering it with `render_branch` and encoding the
/// resulting text with `tokenizer::encode`.
fn tokenize_branch(role: Role, children: &[AstNode]) -> Vec<u32> {
    let rendered = render_branch(role, children);
    tokenizer::encode(&rendered)
}
fn format_tool_call_xml(name: &str, args_json: &str) -> String { fn format_tool_call_xml(name: &str, args_json: &str) -> String {
let args: serde_json::Value = serde_json::from_str(args_json) let args: serde_json::Value = serde_json::from_str(args_json)
.unwrap_or(serde_json::Value::Object(Default::default())); .unwrap_or(serde_json::Value::Object(Default::default()));
@ -567,26 +569,15 @@ impl ContextState {
pub fn journal(&self) -> &[AstNode] { &self.journal } pub fn journal(&self) -> &[AstNode] { &self.journal }
pub fn conversation(&self) -> &[AstNode] { &self.conversation } pub fn conversation(&self) -> &[AstNode] { &self.conversation }
pub fn tokens(&self) -> usize { fn sections(&self) -> [&Vec<AstNode>; 4] {
self.system.iter().map(|n| n.tokens()).sum::<usize>() [&self.system, &self.identity, &self.journal, &self.conversation]
+ self.identity.iter().map(|n| n.tokens()).sum::<usize>()
+ self.journal.iter().map(|n| n.tokens()).sum::<usize>()
+ self.conversation.iter().map(|n| n.tokens()).sum::<usize>()
} }
}
pub fn token_ids(&self) -> Vec<u32> { impl Ast for ContextState {
let mut ids = Vec::new(); fn render(&self) -> String {
for section in [&self.system, &self.identity, &self.journal, &self.conversation] {
for node in section {
ids.extend(node.token_ids());
}
}
ids
}
pub fn render(&self) -> String {
let mut s = String::new(); let mut s = String::new();
for section in [&self.system, &self.identity, &self.journal, &self.conversation] { for section in self.sections() {
for node in section { for node in section {
s.push_str(&node.render()); s.push_str(&node.render());
} }
@ -594,8 +585,25 @@ impl ContextState {
s s
} }
// -- Mutation -------------------------------------------------------------- fn token_ids(&self) -> Vec<u32> {
let mut ids = Vec::new();
for section in self.sections() {
for node in section {
ids.extend(node.token_ids());
}
}
ids
}
/// Total token count over every node in every section returned by
/// `sections()`.
fn tokens(&self) -> usize {
    let mut total = 0;
    for section in self.sections() {
        for node in section {
            total += node.tokens();
        }
    }
    total
}
}
impl ContextState {
fn section_mut(&mut self, section: Section) -> &mut Vec<AstNode> { fn section_mut(&mut self, section: Section) -> &mut Vec<AstNode> {
match section { match section {
Section::System => &mut self.system, Section::System => &mut self.system,