WIP: Context AST design — AstNode with Leaf{text,token_ids}/Branch

New context_new.rs with the AST-based context window design:
- AstNode: role + NodeBody (Leaf with text+token_ids, or Branch with children)
- Tokens only on leaves, branches walk children
- render() produces UTF-8, tokenize produces token IDs, same path
- ResponseParser state machine for streaming assistant responses
- Role enum covers all node types including sections

Still needs: fix remaining pattern match issues, add ContextState wrapper,
wire into mod.rs, replace old context.rs.

Does not compile yet — this is a design checkpoint.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-08 12:46:44 -04:00
parent 64157d8fd7
commit 29dc339f54
2 changed files with 524 additions and 0 deletions

523
src/agent/context_new.rs Normal file
View file

@ -0,0 +1,523 @@
// context.rs — Context window as an AST
//
// The context window is a tree of AstNodes. Each node has a role, a body
// (either leaf content or children), and cached token IDs. The full prompt
// for the model is a depth-first traversal. Streaming responses are parsed
// into new nodes by the ResponseParser.
//
// Grammar (EBNF):
//
// context = section* ;
// section = message* ;
// message = IM_START role "\n" body IM_END "\n" ;
// body = content | element* ; (* leaf or branch *)
// element = thinking | tool_call | content ;
// thinking = "<think>" content "</think>" ;
// tool_call = "<tool_call>\n" tool_xml "\n</tool_call>" ;
// tool_xml = "<function=" NAME ">\n" param* "</function>" ;
// param = "<parameter=" NAME ">\n" VALUE "\n</parameter>\n" ;
// content = TOKEN* ;
//
// The AST is uniform: one AstNode type for everything. The grammar
// constraints are enforced by construction, not by the type system.
use chrono::{DateTime, Utc};
use super::tokenizer;
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
/// What a node in the context AST represents.
///
/// Roles fall into three groups: sections (top-level containers), messages
/// (wrapped in im_start/im_end when rendered), and the children of an
/// Assistant message (thinking, tool calls, plain content).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    // Sections (top-level groupings)
    SystemSection,
    IdentitySection,
    JournalSection,
    ConversationSection,
    // Messages (top-level, get im_start/im_end wrapping)
    System,
    User,
    Assistant,
    Tool,
    Dmn,
    Memory,
    Log,
    // Children of Assistant
    Thinking,
    ToolCall,
    Content,
}
impl Role {
    /// The role string emitted into the chat template for this node.
    /// Several roles map onto the same template role (e.g. the identity
    /// and journal sections render as "user").
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::SystemSection | Self::System => "system",
            Self::IdentitySection | Self::JournalSection | Self::User => "user",
            Self::Assistant => "assistant",
            Self::Tool => "tool",
            Self::Dmn => "dmn",
            Self::Memory => "memory",
            Self::Log => "log",
            Self::Thinking => "thinking",
            Self::ToolCall => "tool_call",
            Self::Content => "content",
            Self::ConversationSection => "conversation",
        }
    }
    /// Whether this node contributes tokens to the prompt.
    /// Thinking and Log nodes are kept in the tree but never rendered.
    pub fn is_prompt_visible(&self) -> bool {
        match self {
            Self::Thinking | Self::Log => false,
            _ => true,
        }
    }
    /// Whether this is a top-level message (gets im_start/im_end wrapping).
    pub fn is_message(&self) -> bool {
        matches!(
            self,
            Self::System
                | Self::User
                | Self::Assistant
                | Self::Tool
                | Self::Dmn
                | Self::Memory
                | Self::Log
        )
    }
    /// Whether this is a section (container for messages).
    pub fn is_section(&self) -> bool {
        matches!(
            self,
            Self::SystemSection
                | Self::IdentitySection
                | Self::JournalSection
                | Self::ConversationSection
        )
    }
}
impl AstNode {
/// Create a leaf node (content string, no children).
pub fn leaf(role: Role, content: impl Into<String>) -> Self {
let content = content.into();
let token_ids = tokenize_leaf(role, &content);
Self {
role, body: NodeBody::Leaf { text: content, token_ids },
timestamp: None, memory_key: None, memory_score: None,
tool_name: None, tool_args: None, tool_call_id: None,
}
}
/// Create a branch node (children, no direct content).
pub fn branch(role: Role, children: Vec<AstNode>) -> Self {
Self {
role, body: NodeBody::Branch(children),
timestamp: None, memory_key: None, memory_score: None,
tool_name: None, tool_args: None, tool_call_id: None,
}
}
/// Create a memory node.
pub fn memory(key: impl Into<String>, content: impl Into<String>) -> Self {
let mut node = Self::leaf(Role::Memory, content);
node.memory_key = Some(key.into());
node
}
/// Create a tool call node.
pub fn tool_call(id: impl Into<String>, name: impl Into<String>, args: impl Into<String>) -> Self {
let name = name.into();
let args = args.into();
let id = id.into();
// Format the XML body for tokenization
let xml = format_tool_call_xml(&name, &args);
let mut node = Self::leaf(Role::ToolCall, xml);
node.tool_name = Some(name);
node.tool_args = Some(args);
node.tool_call_id = Some(id);
node
}
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
self.timestamp = Some(ts);
self
}
/// Token count — leaf returns cached len, branch sums children recursively.
pub fn tokens(&self) -> usize {
match &self.body {
NodeBody::Leaf { token_ids, .. } => token_ids.len(),
NodeBody::Branch(children) => children.iter().map(|c| c.tokens()).sum(),
}
}
/// Get token IDs — leaf returns cached, branch walks children.
pub fn token_ids(&self) -> Vec<u32> {
match &self.body {
NodeBody::Leaf { token_ids, .. } => token_ids.clone(),
NodeBody::Branch(children) => {
tokenize_branch(self.role, children)
}
}
}
/// Render this node to text (same output as tokenization, but UTF-8).
pub fn render(&self) -> String {
match &self.body {
NodeBody::Leaf { text, .. } => render_leaf(self.role, text),
NodeBody::Branch(children) => render_branch(self.role, children),
}
}
/// Get the content string (leaf nodes only).
pub fn content(&self) -> &str {
match &self.body {
NodeBody::Leaf { text, .. } => text,
NodeBody::Branch(_) => "",
}
}
/// Get children (branch nodes only).
pub fn children(&self) -> &[AstNode] {
match &self.body {
NodeBody::Branch(c) => c,
NodeBody::Leaf { .. } => &[],
}
}
/// Get mutable children.
pub fn children_mut(&mut self) -> Option<&mut Vec<AstNode>> {
match &mut self.body {
NodeBody::Branch(c) => Some(c),
NodeBody::Leaf { .. } => None,
}
}
/// Push a child node. Only valid on Branch nodes.
pub fn push_child(&mut self, child: AstNode) {
match &mut self.body {
NodeBody::Branch(children) => children.push(child),
NodeBody::Leaf { .. } => panic!("push_child on leaf node"),
}
}
/// Set score on a Memory node.
pub fn set_score(&mut self, score: Option<f64>) {
self.memory_score = score;
}
/// Short label for the UI.
pub fn label(&self) -> String {
let cfg = crate::config::get();
match self.role {
Role::System => "system".to_string(),
Role::User => format!("{}: {}", cfg.user_name,
truncate_preview(self.content(), 60)),
Role::Assistant => format!("{}: {}", cfg.assistant_name,
truncate_preview(self.content(), 60)),
Role::Tool => "tool_result".to_string(),
Role::Dmn => "dmn".to_string(),
Role::Memory => match (&self.memory_key, self.memory_score) {
(Some(key), Some(s)) => format!("mem: {} score:{:.1}", key, s),
(Some(key), None) => format!("mem: {}", key),
_ => "mem".to_string(),
},
Role::Thinking => format!("thinking: {}",
truncate_preview(self.content(), 60)),
Role::ToolCall => format!("tool_call: {}",
self.tool_name.as_deref().unwrap_or("?")),
Role::Content => truncate_preview(self.content(), 60),
Role::Log => format!("log: {}",
truncate_preview(self.content(), 60)),
_ => self.role.as_str().to_string(),
}
}
}
/// Truncate `s` to at most `max` characters for a one-line preview,
/// replacing newlines with spaces and appending "..." when truncated.
///
/// The truncation test counts characters, matching the `chars().take(max)`
/// budget above. (The previous `s.len() > max` compared *bytes* against a
/// char budget, so multi-byte UTF-8 strings of <= max chars got a spurious
/// ellipsis.)
fn truncate_preview(s: &str, max: usize) -> String {
    let mut chars = s.chars();
    let preview: String = chars.by_ref().take(max).collect();
    let preview = preview.replace('\n', " ");
    // Anything left on the iterator means we actually cut something off.
    if chars.next().is_some() { format!("{}...", preview) } else { preview }
}
// ---------------------------------------------------------------------------
// Serialization — two modes, same output
// ---------------------------------------------------------------------------
/// Serialize a leaf node to prompt text.
///
/// Non-visible roles render as the empty string; message roles get
/// im_start/im_end framing; assistant children get their tag wrapping;
/// everything else passes through unchanged.
fn render_leaf(role: Role, content: &str) -> String {
    if !role.is_prompt_visible() {
        return String::new();
    }
    if role.is_message() {
        return format!("<|im_start|>{}\n{}<|im_end|>\n", role.as_str(), content);
    }
    match role {
        Role::ToolCall => format!("<tool_call>\n{}\n</tool_call>\n", content),
        // Unreachable in practice: Thinking is filtered by the visibility
        // guard above. Kept for completeness.
        Role::Thinking => format!("<think>\n{}\n</think>\n", content),
        // Content and any other non-message role render bare.
        _ => content.to_string(),
    }
}
/// Render a branch node to text.
fn render_branch(role: Role, children: &[AstNode]) -> String {
if !role.is_prompt_visible() {
return String::new();
}
if role.is_section() {
children.iter().map(|c| c.render()).collect()
} else if role == Role::Assistant {
let mut s = String::from("<|im_start|>assistant\n");
for child in children {
s.push_str(&child.render());
}
s.push_str("<|im_end|>\n");
s
} else {
children.iter().map(|c| c.render()).collect()
}
}
/// Token IDs for a leaf: encode its rendered text, or nothing for roles
/// hidden from the prompt. Encoding the rendered form keeps tokenization
/// and text rendering on the same code path.
fn tokenize_leaf(role: Role, content: &str) -> Vec<u32> {
    if role.is_prompt_visible() {
        tokenizer::encode(&render_leaf(role, content))
    } else {
        Vec::new()
    }
}
/// Token IDs for a branch: encode the branch's rendered text (which
/// includes any Assistant framing), or nothing for hidden roles.
fn tokenize_branch(role: Role, children: &[AstNode]) -> Vec<u32> {
    if role.is_prompt_visible() {
        tokenizer::encode(&render_branch(role, children))
    } else {
        Vec::new()
    }
}
/// Build the Qwen-style XML body for a tool call: a `<function=NAME>`
/// element containing one `<parameter=KEY>` element per argument.
/// Arguments that fail to parse as JSON degrade to an empty object, so a
/// malformed call still produces a well-formed (if empty) function element.
fn format_tool_call_xml(name: &str, args_json: &str) -> String {
    let parsed: serde_json::Value = serde_json::from_str(args_json)
        .unwrap_or(serde_json::Value::Object(Default::default()));
    let mut out = format!("<function={}>\n", name);
    if let Some(map) = parsed.as_object() {
        for (key, value) in map {
            // Strings are emitted raw (no JSON quoting); other values keep
            // their JSON serialization.
            let rendered = if let serde_json::Value::String(s) = value {
                s.clone()
            } else {
                value.to_string()
            };
            out.push_str(&format!("<parameter={}>\n{}\n</parameter>\n", key, rendered));
        }
    }
    out.push_str("</function>");
    out
}
// ---------------------------------------------------------------------------
// Streaming response parser
// ---------------------------------------------------------------------------
/// State machine for parsing a streaming assistant response into an AstNode.
/// Feed text chunks as they arrive; completed tool calls are returned for
/// immediate dispatch.
pub struct ResponseParser {
    /// Buffered text not yet committed to a node
    buf: String,
    /// Content fragments collected so far (between tags)
    content_parts: Vec<String>,
    /// Completed child nodes (Thinking / ToolCall / Content, in stream order)
    children: Vec<AstNode>,
    /// Parse state
    // True while inside an unclosed <think> tag
    in_think: bool,
    // Thinking text accumulated since the opening <think>
    think_buf: String,
    // True while inside an unclosed <tool_call> tag
    in_tool_call: bool,
    // Tool-call body accumulated since the opening <tool_call>
    tool_call_buf: String,
}
impl ResponseParser {
    /// Create a parser with empty buffers, positioned outside any tag.
    pub fn new() -> Self {
        Self {
            buf: String::new(),
            content_parts: Vec::new(),
            children: Vec::new(),
            in_think: false,
            think_buf: String::new(),
            in_tool_call: false,
            tool_call_buf: String::new(),
        }
    }
    /// Feed a text chunk. Returns newly completed tool call nodes
    /// (for immediate dispatch).
    pub fn feed(&mut self, text: &str) -> Vec<AstNode> {
        let mut new_calls = vec![];
        self.buf.push_str(text);
        loop {
            if self.in_think {
                match self.buf.find("</think>") {
                    Some(end) => {
                        self.think_buf.push_str(&self.buf[..end]);
                        // skip past "</think>" (8 bytes)
                        self.buf = self.buf[end + 8..].to_string();
                        self.in_think = false;
                        self.children.push(AstNode::leaf(Role::Thinking, &self.think_buf));
                        self.think_buf.clear();
                        continue;
                    }
                    None => {
                        // No closing tag yet: buffer everything as thinking.
                        self.think_buf.push_str(&self.buf);
                        self.buf.clear();
                        break;
                    }
                }
            }
            if self.in_tool_call {
                match self.buf.find("</tool_call>") {
                    Some(end) => {
                        self.tool_call_buf.push_str(&self.buf[..end]);
                        // skip past "</tool_call>" (12 bytes)
                        self.buf = self.buf[end + 12..].to_string();
                        self.in_tool_call = false;
                        // Unparseable bodies are dropped silently: no child
                        // node, no dispatch.
                        if let Some(call) = super::api::parsing::parse_tool_call_body(&self.tool_call_buf) {
                            let node = AstNode {
                                role: Role::ToolCall,
                                // FIX: Leaf is a struct variant and the token
                                // cache lives inside it, not on AstNode.
                                // Tokenized later, when attached to a parent.
                                body: NodeBody::Leaf {
                                    text: self.tool_call_buf.clone(),
                                    token_ids: vec![],
                                },
                                timestamp: None,
                                memory_key: None,
                                memory_score: None,
                                tool_name: Some(call.function.name),
                                tool_args: Some(call.function.arguments),
                                tool_call_id: Some(call.id),
                            };
                            new_calls.push(node.clone());
                            self.flush_content();
                            self.children.push(node);
                        }
                        self.tool_call_buf.clear();
                        continue;
                    }
                    None => {
                        self.tool_call_buf.push_str(&self.buf);
                        self.buf.clear();
                        break;
                    }
                }
            }
            // Outside any tag: look for the next tag opening.
            let think_pos = self.buf.find("<think>");
            let tool_pos = self.buf.find("<tool_call>");
            let next_tag = match (think_pos, tool_pos) {
                (Some(a), Some(b)) => Some(a.min(b)),
                (Some(a), None) => Some(a),
                (None, Some(b)) => Some(b),
                (None, None) => None,
            };
            match next_tag {
                Some(pos) => {
                    // Content before the tag
                    if pos > 0 {
                        self.content_parts.push(self.buf[..pos].to_string());
                    }
                    if self.buf[pos..].starts_with("<think>") {
                        // skip "<think>" (7 bytes)
                        self.buf = self.buf[pos + 7..].to_string();
                        self.flush_content();
                        self.in_think = true;
                    } else {
                        // skip "<tool_call>" (11 bytes)
                        self.buf = self.buf[pos + 11..].to_string();
                        self.flush_content();
                        self.in_tool_call = true;
                    }
                    continue;
                }
                None => {
                    // No complete tag. Keep last 11 bytes as lookahead
                    // (length of "<tool_call>") to handle partial tags.
                    let mut safe = self.buf.len().saturating_sub(11);
                    // FIX: str::floor_char_boundary is nightly-only; walk
                    // back to a char boundary by hand so the slice can't
                    // split a multi-byte UTF-8 sequence.
                    while safe > 0 && !self.buf.is_char_boundary(safe) {
                        safe -= 1;
                    }
                    if safe > 0 {
                        self.content_parts.push(self.buf[..safe].to_string());
                        self.buf = self.buf[safe..].to_string();
                    }
                    break;
                }
            }
        }
        new_calls
    }
    /// Flush accumulated content fragments into a single Content child node.
    /// A no-op when there is nothing buffered (or only empty fragments).
    fn flush_content(&mut self) {
        if !self.content_parts.is_empty() {
            let text: String = self.content_parts.drain(..).collect();
            if !text.is_empty() {
                self.children.push(AstNode::leaf(Role::Content, text));
            }
        }
    }
    /// Finalize the parse. Returns the completed assistant AstNode.
    pub fn finish(mut self) -> AstNode {
        // Remaining buffer is content
        if !self.buf.is_empty() {
            self.content_parts.push(std::mem::take(&mut self.buf));
        }
        self.flush_content();
        AstNode::branch(Role::Assistant, self.children)
    }
    /// Pending (not yet flushed) content, for streaming to the UI.
    /// Content already flushed into child nodes is not included.
    pub fn display_content(&self) -> String {
        self.content_parts.join("")
    }
}
// ---------------------------------------------------------------------------
// Context window size
// ---------------------------------------------------------------------------
/// Context window size in tokens, as configured for the API
/// (`api_context_window` in the global config).
pub fn context_window() -> usize {
    crate::config::get().api_context_window
}
/// Token budget for prompt assembly: 80% of the model's context window,
/// leaving the remaining 20% as headroom for the response.
pub fn context_budget_tokens() -> usize {
    let window = context_window();
    window * 80 / 100
}
/// Detect context window overflow errors from the API by sniffing the
/// lowercased error text for provider-specific phrases.
pub fn is_context_overflow(err: &anyhow::Error) -> bool {
    const OVERFLOW_PHRASES: &[&str] = &[
        "context length",
        "token limit",
        "too many tokens",
        "maximum context",
        "prompt is too long",
        "request too large",
        "input validation error",
        "content length limit",
    ];
    let msg = err.to_string().to_lowercase();
    OVERFLOW_PHRASES.iter().any(|phrase| msg.contains(phrase))
        || (msg.contains("400") && msg.contains("tokens"))
}
/// Detect model/provider errors delivered inside the SSE stream, which
/// surface with a "model stream error" marker in the error text.
pub fn is_stream_error(err: &anyhow::Error) -> bool {
    let msg = err.to_string();
    msg.contains("model stream error")
}

View file

@ -15,6 +15,7 @@
pub mod api; pub mod api;
pub mod context; pub mod context;
pub mod context_new;
pub mod oneshot; pub mod oneshot;
pub mod tokenizer; pub mod tokenizer;
pub mod tools; pub mod tools;