diff --git a/src/agent/context_new.rs b/src/agent/context_new.rs
new file mode 100644
index 0000000..9317dbb
--- /dev/null
+++ b/src/agent/context_new.rs
@@ -0,0 +1,523 @@
+// context.rs — Context window as an AST
+//
+// The context window is a tree of AstNodes. Each node has a role, a body
+// (either leaf content or children), and cached token IDs. The full prompt
+// for the model is a depth-first traversal. Streaming responses are parsed
+// into new nodes by the ResponseParser.
+//
+// Grammar (EBNF):
+//
+//   context   = section* ;
+//   section   = message* ;
+//   message   = IM_START role "\n" body IM_END "\n" ;
+//   body      = content | element* ;           (* leaf or branch *)
+//   element   = thinking | tool_call | content ;
+//   thinking  = "<think>" content "</think>" ;
+//   tool_call = "<tool_call>\n" tool_xml "\n</tool_call>" ;
+//   tool_xml  = "<function=NAME>\n" param* "</function>" ;
+//   param     = "<parameter=KEY>\n" VALUE "\n</parameter>\n" ;
+//   content   = TOKEN* ;
+//
+// The AST is uniform: one AstNode type for everything. The grammar
+// constraints are enforced by construction, not by the type system.
+ +use chrono::{DateTime, Utc}; +use super::tokenizer; + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Role { + // Sections (top-level groupings) + SystemSection, + IdentitySection, + JournalSection, + ConversationSection, + + // Messages (top-level, get im_start/im_end wrapping) + System, + User, + Assistant, + Tool, + Dmn, + Memory, + Log, + + // Children of Assistant + Thinking, + ToolCall, + Content, +} + +#[derive(Debug, Clone)] +pub enum NodeBody { + Leaf { text: String, token_ids: Vec }, + Branch(Vec), +} + +#[derive(Debug, Clone)] +pub struct AstNode { + pub role: Role, + pub body: NodeBody, + pub timestamp: Option>, + + // Optional metadata + pub memory_key: Option, + pub memory_score: Option, + pub tool_name: Option, + pub tool_args: Option, + pub tool_call_id: Option, +} + +impl Role { + pub fn as_str(&self) -> &'static str { + match self { + Self::System | Self::SystemSection => "system", + Self::User | Self::IdentitySection | Self::JournalSection => "user", + Self::Assistant => "assistant", + Self::Tool => "tool", + Self::Dmn => "dmn", + Self::Memory => "memory", + Self::Log => "log", + Self::Thinking => "thinking", + Self::ToolCall => "tool_call", + Self::Content => "content", + Self::ConversationSection => "conversation", + } + } + + /// Whether this node contributes tokens to the prompt. + pub fn is_prompt_visible(&self) -> bool { + !matches!(self, Self::Thinking | Self::Log) + } + + /// Whether this is a top-level message (gets im_start/im_end wrapping). + pub fn is_message(&self) -> bool { + matches!(self, Self::System | Self::User | Self::Assistant | + Self::Tool | Self::Dmn | Self::Memory | Self::Log) + } + + /// Whether this is a section (container for messages). 
+ pub fn is_section(&self) -> bool { + matches!(self, Self::SystemSection | Self::IdentitySection | + Self::JournalSection | Self::ConversationSection) + } +} + +impl AstNode { + /// Create a leaf node (content string, no children). + pub fn leaf(role: Role, content: impl Into) -> Self { + let content = content.into(); + let token_ids = tokenize_leaf(role, &content); + Self { + role, body: NodeBody::Leaf { text: content, token_ids }, + timestamp: None, memory_key: None, memory_score: None, + tool_name: None, tool_args: None, tool_call_id: None, + } + } + + /// Create a branch node (children, no direct content). + pub fn branch(role: Role, children: Vec) -> Self { + Self { + role, body: NodeBody::Branch(children), + timestamp: None, memory_key: None, memory_score: None, + tool_name: None, tool_args: None, tool_call_id: None, + } + } + + /// Create a memory node. + pub fn memory(key: impl Into, content: impl Into) -> Self { + let mut node = Self::leaf(Role::Memory, content); + node.memory_key = Some(key.into()); + node + } + + /// Create a tool call node. + pub fn tool_call(id: impl Into, name: impl Into, args: impl Into) -> Self { + let name = name.into(); + let args = args.into(); + let id = id.into(); + // Format the XML body for tokenization + let xml = format_tool_call_xml(&name, &args); + let mut node = Self::leaf(Role::ToolCall, xml); + node.tool_name = Some(name); + node.tool_args = Some(args); + node.tool_call_id = Some(id); + node + } + + pub fn with_timestamp(mut self, ts: DateTime) -> Self { + self.timestamp = Some(ts); + self + } + + /// Token count — leaf returns cached len, branch sums children recursively. + pub fn tokens(&self) -> usize { + match &self.body { + NodeBody::Leaf { token_ids, .. } => token_ids.len(), + NodeBody::Branch(children) => children.iter().map(|c| c.tokens()).sum(), + } + } + + /// Get token IDs — leaf returns cached, branch walks children. + pub fn token_ids(&self) -> Vec { + match &self.body { + NodeBody::Leaf { token_ids, .. 
} => token_ids.clone(), + NodeBody::Branch(children) => { + tokenize_branch(self.role, children) + } + } + } + + /// Render this node to text (same output as tokenization, but UTF-8). + pub fn render(&self) -> String { + match &self.body { + NodeBody::Leaf { text, .. } => render_leaf(self.role, text), + NodeBody::Branch(children) => render_branch(self.role, children), + } + } + + /// Get the content string (leaf nodes only). + pub fn content(&self) -> &str { + match &self.body { + NodeBody::Leaf { text, .. } => text, + NodeBody::Branch(_) => "", + } + } + + /// Get children (branch nodes only). + pub fn children(&self) -> &[AstNode] { + match &self.body { + NodeBody::Branch(c) => c, + NodeBody::Leaf { .. } => &[], + } + } + + /// Get mutable children. + pub fn children_mut(&mut self) -> Option<&mut Vec> { + match &mut self.body { + NodeBody::Branch(c) => Some(c), + NodeBody::Leaf { .. } => None, + } + } + + /// Push a child node. Only valid on Branch nodes. + pub fn push_child(&mut self, child: AstNode) { + match &mut self.body { + NodeBody::Branch(children) => children.push(child), + NodeBody::Leaf { .. } => panic!("push_child on leaf node"), + } + } + + /// Set score on a Memory node. + pub fn set_score(&mut self, score: Option) { + self.memory_score = score; + } + + /// Short label for the UI. 
+ pub fn label(&self) -> String { + let cfg = crate::config::get(); + match self.role { + Role::System => "system".to_string(), + Role::User => format!("{}: {}", cfg.user_name, + truncate_preview(self.content(), 60)), + Role::Assistant => format!("{}: {}", cfg.assistant_name, + truncate_preview(self.content(), 60)), + Role::Tool => "tool_result".to_string(), + Role::Dmn => "dmn".to_string(), + Role::Memory => match (&self.memory_key, self.memory_score) { + (Some(key), Some(s)) => format!("mem: {} score:{:.1}", key, s), + (Some(key), None) => format!("mem: {}", key), + _ => "mem".to_string(), + }, + Role::Thinking => format!("thinking: {}", + truncate_preview(self.content(), 60)), + Role::ToolCall => format!("tool_call: {}", + self.tool_name.as_deref().unwrap_or("?")), + Role::Content => truncate_preview(self.content(), 60), + Role::Log => format!("log: {}", + truncate_preview(self.content(), 60)), + _ => self.role.as_str().to_string(), + } + } +} + +fn truncate_preview(s: &str, max: usize) -> String { + let preview: String = s.chars().take(max).collect(); + let preview = preview.replace('\n', " "); + if s.len() > max { format!("{}...", preview) } else { preview } +} + +// --------------------------------------------------------------------------- +// Serialization — two modes, same output +// --------------------------------------------------------------------------- + +/// Render a leaf node to text. +fn render_leaf(role: Role, content: &str) -> String { + if !role.is_prompt_visible() { + return String::new(); + } + + if role.is_message() { + format!("<|im_start|>{}\n{}<|im_end|>\n", role.as_str(), content) + } else { + match role { + Role::Thinking => format!("\n{}\n\n", content), + Role::ToolCall => format!("\n{}\n\n", content), + Role::Content => content.to_string(), + _ => content.to_string(), + } + } +} + +/// Render a branch node to text. 
+fn render_branch(role: Role, children: &[AstNode]) -> String { + if !role.is_prompt_visible() { + return String::new(); + } + + if role.is_section() { + children.iter().map(|c| c.render()).collect() + } else if role == Role::Assistant { + let mut s = String::from("<|im_start|>assistant\n"); + for child in children { + s.push_str(&child.render()); + } + s.push_str("<|im_end|>\n"); + s + } else { + children.iter().map(|c| c.render()).collect() + } +} + +/// Tokenize a leaf node. +fn tokenize_leaf(role: Role, content: &str) -> Vec { + if !role.is_prompt_visible() { + return vec![]; + } + tokenizer::encode(&render_leaf(role, content)) +} + +/// Tokenize a branch node from its children. +fn tokenize_branch(role: Role, children: &[AstNode]) -> Vec { + if !role.is_prompt_visible() { + return vec![]; + } + tokenizer::encode(&render_branch(role, children)) +} + +/// Format a tool call as the XML body expected by Qwen. +fn format_tool_call_xml(name: &str, args_json: &str) -> String { + let args: serde_json::Value = serde_json::from_str(args_json) + .unwrap_or(serde_json::Value::Object(Default::default())); + let mut xml = format!("\n", name); + if let Some(obj) = args.as_object() { + for (key, value) in obj { + let val_str = match value { + serde_json::Value::String(s) => s.clone(), + other => other.to_string(), + }; + xml.push_str(&format!("\n{}\n\n", key, val_str)); + } + } + xml.push_str(""); + xml +} + +// --------------------------------------------------------------------------- +// Streaming response parser +// --------------------------------------------------------------------------- + +/// State machine for parsing a streaming assistant response into an AstNode. +/// Feed text chunks as they arrive; completed tool calls are returned for +/// immediate dispatch. 
+pub struct ResponseParser { + /// Buffered text not yet committed to a node + buf: String, + /// Content fragments collected so far (between tags) + content_parts: Vec, + /// Completed child nodes + children: Vec, + /// Parse state + in_think: bool, + think_buf: String, + in_tool_call: bool, + tool_call_buf: String, +} + +impl ResponseParser { + pub fn new() -> Self { + Self { + buf: String::new(), + content_parts: Vec::new(), + children: Vec::new(), + in_think: false, + think_buf: String::new(), + in_tool_call: false, + tool_call_buf: String::new(), + } + } + + /// Feed a text chunk. Returns newly completed tool call nodes + /// (for immediate dispatch). + pub fn feed(&mut self, text: &str) -> Vec { + let mut new_calls = vec![]; + self.buf.push_str(text); + + loop { + if self.in_think { + match self.buf.find("") { + Some(end) => { + self.think_buf.push_str(&self.buf[..end]); + self.buf = self.buf[end + 8..].to_string(); + self.in_think = false; + self.children.push(AstNode::leaf(Role::Thinking, &self.think_buf)); + self.think_buf.clear(); + continue; + } + None => { + self.think_buf.push_str(&self.buf); + self.buf.clear(); + break; + } + } + } + + if self.in_tool_call { + match self.buf.find("") { + Some(end) => { + self.tool_call_buf.push_str(&self.buf[..end]); + self.buf = self.buf[end + 12..].to_string(); + self.in_tool_call = false; + if let Some(call) = super::api::parsing::parse_tool_call_body(&self.tool_call_buf) { + let node = AstNode { + role: Role::ToolCall, + body: NodeBody::Leaf(self.tool_call_buf.clone()), + token_ids: vec![], // tokenized when attached to parent + timestamp: None, + memory_key: None, memory_score: None, + tool_name: Some(call.function.name), + tool_args: Some(call.function.arguments), + tool_call_id: Some(call.id), + }; + new_calls.push(node.clone()); + self.flush_content(); + self.children.push(node); + } + self.tool_call_buf.clear(); + continue; + } + None => { + self.tool_call_buf.push_str(&self.buf); + self.buf.clear(); + break; 
+ } + } + } + + // Look for tag openings + let think_pos = self.buf.find(""); + let tool_pos = self.buf.find(""); + let next_tag = match (think_pos, tool_pos) { + (Some(a), Some(b)) => Some(a.min(b)), + (Some(a), None) => Some(a), + (None, Some(b)) => Some(b), + (None, None) => None, + }; + + match next_tag { + Some(pos) => { + // Content before the tag + if pos > 0 { + self.content_parts.push(self.buf[..pos].to_string()); + } + if self.buf[pos..].starts_with("") { + self.buf = self.buf[pos + 7..].to_string(); + self.flush_content(); + self.in_think = true; + } else { + self.buf = self.buf[pos + 11..].to_string(); + self.flush_content(); + self.in_tool_call = true; + } + continue; + } + None => { + // No complete tag. Keep last 11 chars as lookahead + // (length of "") to handle partial tags. + let safe = self.buf.len().saturating_sub(11); + if safe > 0 { + let safe = self.buf.floor_char_boundary(safe); + self.content_parts.push(self.buf[..safe].to_string()); + self.buf = self.buf[safe..].to_string(); + } + break; + } + } + } + + new_calls + } + + /// Flush accumulated content into a Content child node. + fn flush_content(&mut self) { + if !self.content_parts.is_empty() { + let text: String = self.content_parts.drain(..).collect(); + if !text.is_empty() { + self.children.push(AstNode::leaf(Role::Content, text)); + } + } + } + + /// Finalize the parse. Returns the completed assistant AstNode. + pub fn finish(mut self) -> AstNode { + // Remaining buffer is content + if !self.buf.is_empty() { + self.content_parts.push(std::mem::take(&mut self.buf)); + } + self.flush_content(); + + AstNode::branch(Role::Assistant, self.children) + } + + /// Get the current display text (for streaming to UI). + /// Returns content accumulated since the last call. 
+ pub fn display_content(&self) -> String { + self.content_parts.join("") + } +} + +// --------------------------------------------------------------------------- +// Context window size +// --------------------------------------------------------------------------- + +/// Context window size in tokens (from config). +pub fn context_window() -> usize { + crate::config::get().api_context_window +} + +/// Context budget in tokens: 80% of the model's context window. +pub fn context_budget_tokens() -> usize { + context_window() * 80 / 100 +} + +/// Detect context window overflow errors from the API. +pub fn is_context_overflow(err: &anyhow::Error) -> bool { + let msg = err.to_string().to_lowercase(); + msg.contains("context length") + || msg.contains("token limit") + || msg.contains("too many tokens") + || msg.contains("maximum context") + || msg.contains("prompt is too long") + || msg.contains("request too large") + || msg.contains("input validation error") + || msg.contains("content length limit") + || (msg.contains("400") && msg.contains("tokens")) +} + +/// Detect model/provider errors delivered inside the SSE stream. +pub fn is_stream_error(err: &anyhow::Error) -> bool { + err.to_string().contains("model stream error") +} diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 305721a..c279e41 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -15,6 +15,7 @@ pub mod api; pub mod context; +pub mod context_new; pub mod oneshot; pub mod tokenizer; pub mod tools;