Add Ast trait for render/token_ids/tokens
Implemented by both AstNode and ContextState, so anything that needs "give me the prompt" can take `impl Ast`.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
f1397b7783
commit
942144949d
1 changed files with 56 additions and 48 deletions
|
|
@ -95,6 +95,12 @@ pub enum Section {
|
|||
Conversation,
|
||||
}
|
||||
|
||||
/// Common interface for anything that can produce prompt content: rendered
/// text, token ids, and a token count.
pub trait Ast {
|
||||
/// Returns the rendered text as an owned `String`.
fn render(&self) -> String;
|
||||
/// Returns the token ids as a `Vec<u32>`.
fn token_ids(&self) -> Vec<u32>;
|
||||
/// Returns a token count.
// NOTE(review): the trait itself does not require this to equal
// `token_ids().len()` — confirm per implementor.
fn tokens(&self) -> usize;
|
||||
}
|
||||
|
||||
/// State machine for parsing a streaming assistant response into an AstNode.
|
||||
/// Feed text chunks as they arrive; completed tool calls are returned for
|
||||
/// immediate dispatch.
|
||||
|
|
@ -243,32 +249,6 @@ impl AstNode {
|
|||
self
|
||||
}
|
||||
|
||||
// -- Accessors ------------------------------------------------------------
|
||||
|
||||
/// Token count for this node: a leaf reports `leaf.tokens()`, a branch
/// reports the sum of its children's counts (computed recursively).
pub fn tokens(&self) -> usize {
|
||||
match self {
|
||||
Self::Leaf(leaf) => leaf.tokens(),
|
||||
Self::Branch { children, .. } =>
|
||||
children.iter().map(|c| c.tokens()).sum(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Token ids for this node: a leaf returns a clone of its stored
/// `token_ids`, a branch delegates to `tokenize_branch` with its role and
/// children.
pub fn token_ids(&self) -> Vec<u32> {
|
||||
match self {
|
||||
Self::Leaf(leaf) => leaf.token_ids.clone(),
|
||||
Self::Branch { role, children } =>
|
||||
tokenize_branch(*role, children),
|
||||
}
|
||||
}
|
||||
|
||||
/// Rendered text for this node: a leaf renders its body, a branch
/// delegates to `render_branch` with its role and children.
pub fn render(&self) -> String {
|
||||
match self {
|
||||
Self::Leaf(leaf) => leaf.body.render(),
|
||||
Self::Branch { role, children } =>
|
||||
render_branch(*role, children),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn children(&self) -> &[AstNode] {
|
||||
match self {
|
||||
Self::Branch { children, .. } => children,
|
||||
|
|
@ -314,6 +294,32 @@ impl AstNode {
|
|||
}
|
||||
}
|
||||
|
||||
/// `Ast` for a single node: leaves answer from their own stored data,
/// branches go through `render_branch` over the role and children.
impl Ast for AstNode {
|
||||
/// Renders this node: a leaf renders its body; a branch delegates to
/// `render_branch`.
fn render(&self) -> String {
|
||||
match self {
|
||||
Self::Leaf(leaf) => leaf.body.render(),
|
||||
Self::Branch { role, children } =>
|
||||
render_branch(*role, children),
|
||||
}
|
||||
}
|
||||
|
||||
/// Token ids: a leaf returns a clone of its stored `token_ids`; a branch
/// renders itself with `render_branch` and encodes that text with
/// `tokenizer::encode`.
fn token_ids(&self) -> Vec<u32> {
|
||||
match self {
|
||||
Self::Leaf(leaf) => leaf.token_ids.clone(),
|
||||
Self::Branch { role, children } =>
|
||||
tokenizer::encode(&render_branch(*role, children)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Token count: a leaf reports `leaf.tokens()`; a branch sums its
/// children's counts.
// NOTE(review): the branch count covers the children only, while
// `token_ids` encodes the whole `render_branch` output; the two may
// disagree if `render_branch` adds text around the children — confirm.
fn tokens(&self) -> usize {
|
||||
match self {
|
||||
Self::Leaf(leaf) => leaf.tokens(),
|
||||
Self::Branch { children, .. } =>
|
||||
children.iter().map(|c| c.tokens()).sum(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn truncate_preview(s: &str, max: usize) -> String {
|
||||
let preview: String = s.chars().take(max).collect();
|
||||
let preview = preview.replace('\n', " ");
|
||||
|
|
@ -329,10 +335,6 @@ fn render_branch(role: Role, children: &[AstNode]) -> String {
|
|||
s
|
||||
}
|
||||
|
||||
/// Tokenizes a branch by rendering it with `render_branch` and passing the
/// resulting text to `tokenizer::encode`.
fn tokenize_branch(role: Role, children: &[AstNode]) -> Vec<u32> {
|
||||
tokenizer::encode(&render_branch(role, children))
|
||||
}
|
||||
|
||||
fn format_tool_call_xml(name: &str, args_json: &str) -> String {
|
||||
let args: serde_json::Value = serde_json::from_str(args_json)
|
||||
.unwrap_or(serde_json::Value::Object(Default::default()));
|
||||
|
|
@ -567,26 +569,15 @@ impl ContextState {
|
|||
pub fn journal(&self) -> &[AstNode] { &self.journal }
|
||||
pub fn conversation(&self) -> &[AstNode] { &self.conversation }
|
||||
|
||||
pub fn tokens(&self) -> usize {
|
||||
self.system.iter().map(|n| n.tokens()).sum::<usize>()
|
||||
+ self.identity.iter().map(|n| n.tokens()).sum::<usize>()
|
||||
+ self.journal.iter().map(|n| n.tokens()).sum::<usize>()
|
||||
+ self.conversation.iter().map(|n| n.tokens()).sum::<usize>()
|
||||
fn sections(&self) -> [&Vec<AstNode>; 4] {
|
||||
[&self.system, &self.identity, &self.journal, &self.conversation]
|
||||
}
|
||||
}
|
||||
|
||||
pub fn token_ids(&self) -> Vec<u32> {
|
||||
let mut ids = Vec::new();
|
||||
for section in [&self.system, &self.identity, &self.journal, &self.conversation] {
|
||||
for node in section {
|
||||
ids.extend(node.token_ids());
|
||||
}
|
||||
}
|
||||
ids
|
||||
}
|
||||
|
||||
pub fn render(&self) -> String {
|
||||
impl Ast for ContextState {
|
||||
fn render(&self) -> String {
|
||||
let mut s = String::new();
|
||||
for section in [&self.system, &self.identity, &self.journal, &self.conversation] {
|
||||
for section in self.sections() {
|
||||
for node in section {
|
||||
s.push_str(&node.render());
|
||||
}
|
||||
|
|
@ -594,8 +585,25 @@ impl ContextState {
|
|||
s
|
||||
}
|
||||
|
||||
// -- Mutation --------------------------------------------------------------
|
||||
fn token_ids(&self) -> Vec<u32> {
|
||||
let mut ids = Vec::new();
|
||||
for section in self.sections() {
|
||||
for node in section {
|
||||
ids.extend(node.token_ids());
|
||||
}
|
||||
}
|
||||
ids
|
||||
}
|
||||
|
||||
fn tokens(&self) -> usize {
|
||||
self.sections().iter()
|
||||
.flat_map(|s| s.iter())
|
||||
.map(|n| n.tokens())
|
||||
.sum()
|
||||
}
|
||||
}
|
||||
|
||||
impl ContextState {
|
||||
fn section_mut(&mut self, section: Section) -> &mut Vec<AstNode> {
|
||||
match section {
|
||||
Section::System => &mut self.system,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue