From 942144949d8d3008311a607fd8dd524e375729d7 Mon Sep 17 00:00:00 2001 From: Kent Overstreet Date: Wed, 8 Apr 2026 13:38:00 -0400 Subject: [PATCH] Add Ast trait for render/token_ids/tokens Implemented by both AstNode and ContextState, so anything that needs "give me the prompt" can take impl Ast. Co-Authored-By: Proof of Concept --- src/agent/context_new.rs | 104 +++++++++++++++++++++------------------ 1 file changed, 56 insertions(+), 48 deletions(-) diff --git a/src/agent/context_new.rs b/src/agent/context_new.rs index 35325f0..8bf95bf 100644 --- a/src/agent/context_new.rs +++ b/src/agent/context_new.rs @@ -95,6 +95,12 @@ pub enum Section { Conversation, } +pub trait Ast { + fn render(&self) -> String; + fn token_ids(&self) -> Vec; + fn tokens(&self) -> usize; +} + /// State machine for parsing a streaming assistant response into an AstNode. /// Feed text chunks as they arrive; completed tool calls are returned for /// immediate dispatch. @@ -243,32 +249,6 @@ impl AstNode { self } - // -- Accessors ------------------------------------------------------------ - - pub fn tokens(&self) -> usize { - match self { - Self::Leaf(leaf) => leaf.tokens(), - Self::Branch { children, .. } => - children.iter().map(|c| c.tokens()).sum(), - } - } - - pub fn token_ids(&self) -> Vec { - match self { - Self::Leaf(leaf) => leaf.token_ids.clone(), - Self::Branch { role, children } => - tokenize_branch(*role, children), - } - } - - pub fn render(&self) -> String { - match self { - Self::Leaf(leaf) => leaf.body.render(), - Self::Branch { role, children } => - render_branch(*role, children), - } - } - pub fn children(&self) -> &[AstNode] { match self { Self::Branch { children, .. } => children, @@ -314,6 +294,32 @@ impl AstNode { } } +impl Ast for AstNode { + fn render(&self) -> String { + match self { + Self::Leaf(leaf) => leaf.body.render(), + Self::Branch { role, children } => + render_branch(*role, children), + } + } + + fn token_ids(&self) -> Vec { + match self { + Self::Leaf(leaf) => leaf.token_ids.clone(), + Self::Branch { role, children } => + tokenizer::encode(&render_branch(*role, children)), + } + } + + fn tokens(&self) -> usize { + match self { + Self::Leaf(leaf) => leaf.tokens(), + Self::Branch { children, .. } => + children.iter().map(|c| c.tokens()).sum(), + } + } +} + fn truncate_preview(s: &str, max: usize) -> String { let preview: String = s.chars().take(max).collect(); let preview = preview.replace('\n', " "); @@ -329,10 +335,6 @@ fn render_branch(role: Role, children: &[AstNode]) -> String { s } -fn tokenize_branch(role: Role, children: &[AstNode]) -> Vec { - tokenizer::encode(&render_branch(role, children)) -} - fn format_tool_call_xml(name: &str, args_json: &str) -> String { let args: serde_json::Value = serde_json::from_str(args_json) .unwrap_or(serde_json::Value::Object(Default::default())); @@ -567,26 +569,15 @@ impl ContextState { pub fn journal(&self) -> &[AstNode] { &self.journal } pub fn conversation(&self) -> &[AstNode] { &self.conversation } - pub fn tokens(&self) -> usize { - self.system.iter().map(|n| n.tokens()).sum::() - + self.identity.iter().map(|n| n.tokens()).sum::() - + self.journal.iter().map(|n| n.tokens()).sum::() - + self.conversation.iter().map(|n| n.tokens()).sum::() + fn sections(&self) -> [&Vec; 4] { + [&self.system, &self.identity, &self.journal, &self.conversation] } +} - pub fn token_ids(&self) -> Vec { - let mut ids = Vec::new(); - for section in [&self.system, &self.identity, &self.journal, &self.conversation] { - for node in section { - ids.extend(node.token_ids()); - } - } - ids - } - - pub fn render(&self) -> String { +impl Ast for ContextState { + fn render(&self) -> String { let mut s = String::new(); - for section in [&self.system, &self.identity, &self.journal, &self.conversation] { + for section in self.sections() { for node in section { s.push_str(&node.render()); } @@ -594,8 +585,25 @@ impl ContextState { s } - // -- Mutation -------------------------------------------------------------- + fn token_ids(&self) -> Vec { + let mut ids = Vec::new(); + for section in self.sections() { + for node in section { + ids.extend(node.token_ids()); + } + } + ids + } + fn tokens(&self) -> usize { + self.sections().iter() + .flat_map(|s| s.iter()) + .map(|n| n.tokens()) + .sum() + } +} + +impl ContextState { fn section_mut(&mut self, section: Section) -> &mut Vec { match section { Section::System => &mut self.system,