Add Ast trait for render/token_ids/tokens
Implemented by both AstNode and ContextState, so anything that needs "give me the prompt" can take impl Ast.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
f1397b7783
commit
942144949d
1 changed file with 56 additions and 48 deletions
|
|
@@ -95,6 +95,12 @@ pub enum Section {
|
||||||
Conversation,
|
Conversation,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait Ast {
|
||||||
|
fn render(&self) -> String;
|
||||||
|
fn token_ids(&self) -> Vec<u32>;
|
||||||
|
fn tokens(&self) -> usize;
|
||||||
|
}
|
||||||
|
|
||||||
/// State machine for parsing a streaming assistant response into an AstNode.
|
/// State machine for parsing a streaming assistant response into an AstNode.
|
||||||
/// Feed text chunks as they arrive; completed tool calls are returned for
|
/// Feed text chunks as they arrive; completed tool calls are returned for
|
||||||
/// immediate dispatch.
|
/// immediate dispatch.
|
||||||
|
|
@@ -243,32 +249,6 @@ impl AstNode {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
// -- Accessors ------------------------------------------------------------
|
|
||||||
|
|
||||||
pub fn tokens(&self) -> usize {
|
|
||||||
match self {
|
|
||||||
Self::Leaf(leaf) => leaf.tokens(),
|
|
||||||
Self::Branch { children, .. } =>
|
|
||||||
children.iter().map(|c| c.tokens()).sum(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn token_ids(&self) -> Vec<u32> {
|
|
||||||
match self {
|
|
||||||
Self::Leaf(leaf) => leaf.token_ids.clone(),
|
|
||||||
Self::Branch { role, children } =>
|
|
||||||
tokenize_branch(*role, children),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn render(&self) -> String {
|
|
||||||
match self {
|
|
||||||
Self::Leaf(leaf) => leaf.body.render(),
|
|
||||||
Self::Branch { role, children } =>
|
|
||||||
render_branch(*role, children),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn children(&self) -> &[AstNode] {
|
pub fn children(&self) -> &[AstNode] {
|
||||||
match self {
|
match self {
|
||||||
Self::Branch { children, .. } => children,
|
Self::Branch { children, .. } => children,
|
||||||
|
|
@@ -314,6 +294,32 @@ impl AstNode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Ast for AstNode {
|
||||||
|
fn render(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Self::Leaf(leaf) => leaf.body.render(),
|
||||||
|
Self::Branch { role, children } =>
|
||||||
|
render_branch(*role, children),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn token_ids(&self) -> Vec<u32> {
|
||||||
|
match self {
|
||||||
|
Self::Leaf(leaf) => leaf.token_ids.clone(),
|
||||||
|
Self::Branch { role, children } =>
|
||||||
|
tokenizer::encode(&render_branch(*role, children)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tokens(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
Self::Leaf(leaf) => leaf.tokens(),
|
||||||
|
Self::Branch { children, .. } =>
|
||||||
|
children.iter().map(|c| c.tokens()).sum(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn truncate_preview(s: &str, max: usize) -> String {
|
fn truncate_preview(s: &str, max: usize) -> String {
|
||||||
let preview: String = s.chars().take(max).collect();
|
let preview: String = s.chars().take(max).collect();
|
||||||
let preview = preview.replace('\n', " ");
|
let preview = preview.replace('\n', " ");
|
||||||
|
|
@@ -329,10 +335,6 @@ fn render_branch(role: Role, children: &[AstNode]) -> String {
|
||||||
s
|
s
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tokenize_branch(role: Role, children: &[AstNode]) -> Vec<u32> {
|
|
||||||
tokenizer::encode(&render_branch(role, children))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn format_tool_call_xml(name: &str, args_json: &str) -> String {
|
fn format_tool_call_xml(name: &str, args_json: &str) -> String {
|
||||||
let args: serde_json::Value = serde_json::from_str(args_json)
|
let args: serde_json::Value = serde_json::from_str(args_json)
|
||||||
.unwrap_or(serde_json::Value::Object(Default::default()));
|
.unwrap_or(serde_json::Value::Object(Default::default()));
|
||||||
|
|
@@ -567,26 +569,15 @@ impl ContextState {
|
||||||
pub fn journal(&self) -> &[AstNode] { &self.journal }
|
pub fn journal(&self) -> &[AstNode] { &self.journal }
|
||||||
pub fn conversation(&self) -> &[AstNode] { &self.conversation }
|
pub fn conversation(&self) -> &[AstNode] { &self.conversation }
|
||||||
|
|
||||||
pub fn tokens(&self) -> usize {
|
fn sections(&self) -> [&Vec<AstNode>; 4] {
|
||||||
self.system.iter().map(|n| n.tokens()).sum::<usize>()
|
[&self.system, &self.identity, &self.journal, &self.conversation]
|
||||||
+ self.identity.iter().map(|n| n.tokens()).sum::<usize>()
|
}
|
||||||
+ self.journal.iter().map(|n| n.tokens()).sum::<usize>()
|
|
||||||
+ self.conversation.iter().map(|n| n.tokens()).sum::<usize>()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn token_ids(&self) -> Vec<u32> {
|
impl Ast for ContextState {
|
||||||
let mut ids = Vec::new();
|
fn render(&self) -> String {
|
||||||
for section in [&self.system, &self.identity, &self.journal, &self.conversation] {
|
|
||||||
for node in section {
|
|
||||||
ids.extend(node.token_ids());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ids
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn render(&self) -> String {
|
|
||||||
let mut s = String::new();
|
let mut s = String::new();
|
||||||
for section in [&self.system, &self.identity, &self.journal, &self.conversation] {
|
for section in self.sections() {
|
||||||
for node in section {
|
for node in section {
|
||||||
s.push_str(&node.render());
|
s.push_str(&node.render());
|
||||||
}
|
}
|
||||||
|
|
@@ -594,8 +585,25 @@ impl ContextState {
|
||||||
s
|
s
|
||||||
}
|
}
|
||||||
|
|
||||||
// -- Mutation --------------------------------------------------------------
|
fn token_ids(&self) -> Vec<u32> {
|
||||||
|
let mut ids = Vec::new();
|
||||||
|
for section in self.sections() {
|
||||||
|
for node in section {
|
||||||
|
ids.extend(node.token_ids());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ids
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tokens(&self) -> usize {
|
||||||
|
self.sections().iter()
|
||||||
|
.flat_map(|s| s.iter())
|
||||||
|
.map(|n| n.tokens())
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ContextState {
|
||||||
fn section_mut(&mut self, section: Section) -> &mut Vec<AstNode> {
|
fn section_mut(&mut self, section: Section) -> &mut Vec<AstNode> {
|
||||||
match section {
|
match section {
|
||||||
Section::System => &mut self.system,
|
Section::System => &mut self.system,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue