forked from kent/consciousness
Images are rendered as `<|vision_start|>` + N × `<|image_pad|>` + `<|vision_end|>` where N is computed from the image dimensions using Qwen3-VL's smart_resize rules (patch_size=16, merge_size=2, min=64K, max=16M pixels). The token count matches what vLLM will produce at request time, so budget accounting stays accurate. Bytes are stored inline on the leaf and base64-encoded in the JSON form. Token IDs are hand-assembled instead of re-running the tokenizer on a potentially-huge placeholder string. Follow-ups: view_image tool rewrite, multi_modal_data on the vLLM request, API-layer plumbing from leaf bytes to request body. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
1564 lines
56 KiB
Rust
1564 lines
56 KiB
Rust
// context.rs — Context window as an AST
|
||
//
|
||
// The context window is a tree of AstNodes. Each node is either a leaf
|
||
// (typed content with cached token IDs) or a branch (role + children).
|
||
// The full prompt is a depth-first traversal of the sections in ContextState.
|
||
// Streaming responses are parsed into new nodes by the ResponseParser.
|
||
//
|
||
// Grammar (EBNF):
|
||
//
|
||
// context = section* ;
|
||
// section = (message | leaf)* ;
|
||
// message = IM_START role "\n" element* IM_END "\n" ;
|
||
// role = "system" | "user" | "assistant" ;
|
||
// element = thinking | tool_call | content ;
|
||
// thinking = "<think>" TEXT "</think>" ;
|
||
// tool_call = "<tool_call>\n" tool_xml "\n</tool_call>" ;
|
||
// tool_xml = "<function=" NAME ">\n" param* "</function>" ;
|
||
// param = "<parameter=" NAME ">\n" VALUE "\n</parameter>\n" ;
|
||
// content = TEXT ;
|
||
//
|
||
// Self-wrapping leaves (not inside a message branch):
|
||
// dmn = IM_START "dmn\n" TEXT IM_END "\n" ;
|
||
// memory = IM_START "memory\n" TEXT IM_END "\n" ;
|
||
// tool_result = IM_START "user\n<tool_response>\n" TEXT "\n</tool_response>" IM_END "\n" ;
|
||
//
|
||
// Non-visible leaves (not in prompt):
|
||
// log = TEXT ;
|
||
//
|
||
// Role is only for branch (interior) nodes. Leaf type is determined by
|
||
// the NodeBody variant. Grammar constraints enforced by construction.
|
||
|
||
use chrono::{DateTime, Utc};
|
||
use serde::{Serialize, Deserialize};
|
||
use std::sync::OnceLock;
|
||
use super::tokenizer;
|
||
|
||
// Cached token lengths for role headers — computed once on first use.
// "system\n", "user\n", "assistant\n" and "\n" are fixed strings.
// Indexed [system, user, assistant]; filled lazily by role_header_tokens().
static ROLE_TOKENS: OnceLock<[usize; 3]> = OnceLock::new();
// Token length of a bare "\n"; filled lazily by newline_tokens().
static NEWLINE_TOKENS: OnceLock<usize> = OnceLock::new();
|
||
|
||
fn role_header_tokens(role: Role) -> usize {
|
||
let tokens = ROLE_TOKENS.get_or_init(|| [
|
||
tokenizer::encode("system\n").len(),
|
||
tokenizer::encode("user\n").len(),
|
||
tokenizer::encode("assistant\n").len(),
|
||
]);
|
||
match role {
|
||
Role::System => tokens[0],
|
||
Role::User => tokens[1],
|
||
Role::Assistant => tokens[2],
|
||
}
|
||
}
|
||
|
||
fn newline_tokens() -> usize {
|
||
*NEWLINE_TOKENS.get_or_init(|| tokenizer::encode("\n").len())
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Types
|
||
// ---------------------------------------------------------------------------
|
||
|
||
/// Branch roles — maps directly to the grammar's message roles.
/// Only branch (interior) nodes carry a role; leaves are typed by NodeBody.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Role {
    System,
    User,
    Assistant,
}
|
||
|
||
/// Leaf content — each variant knows how to render itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NodeBody {
    // Children of message branches — rendered without im_start/im_end
    // (the enclosing Branch supplies the chat-markup envelope).
    Content(String),
    // Model reasoning; kept in the AST but contributes no prompt text
    // (see render_into / is_prompt_visible).
    Thinking(String),
    // `arguments` is the tool-call argument payload as a JSON string.
    ToolCall { name: String, arguments: String },

    // Self-wrapping leaves — render their own im_start/im_end
    ToolResult(String),
    // `key` dedups memories in trim_conversation; `score` is the optional
    // divergence score used to pick eviction victims.
    Memory { key: String, text: String, score: Option<f64> },
    Dmn(String),

    // Vision input — rendered as <|vision_start|> <|image_pad|>×N <|vision_end|>.
    // `token_count` is N, the count vLLM will compute for this image's grid.
    Image {
        // Raw image bytes; base64-encoded in the JSON form (see b64_bytes).
        #[serde(with = "b64_bytes")]
        bytes: Vec<u8>,
        mime: String,
        orig_height: u32,
        orig_width: u32,
        token_count: u32,
    },

    // Non-visible (0 tokens in prompt)
    Log(String),
}
|
||
|
||
mod b64_bytes {
|
||
use base64::{Engine, engine::general_purpose::STANDARD};
|
||
use serde::{Serializer, Deserializer, Deserialize};
|
||
pub fn serialize<S: Serializer>(bytes: &[u8], s: S) -> Result<S::Ok, S::Error> {
|
||
s.serialize_str(&STANDARD.encode(bytes))
|
||
}
|
||
pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
|
||
let s = String::deserialize(d)?;
|
||
STANDARD.decode(s).map_err(serde::de::Error::custom)
|
||
}
|
||
}
|
||
|
||
/// A leaf node: typed content with cached token IDs.
/// Token IDs are not serialized — they're recomputed on deserialization.
/// Invariant: `token_ids` always equals `body.compute_token_ids()`;
/// all mutation paths (NodeLeaf::new, set_message, retokenize) maintain it.
#[derive(Debug, Clone, Serialize)]
pub struct NodeLeaf {
    body: NodeBody,
    // Skipped by serde; rebuilt in the manual Deserialize impl below.
    #[serde(skip)]
    token_ids: Vec<u32>,
    timestamp: DateTime<Utc>,
}
|
||
|
||
// Manual Deserialize: read only `body` + `timestamp`, then re-derive
// `token_ids` so the leaf invariant holds for loaded state (the tokenizer
// may also have changed between runs).
impl<'de> Deserialize<'de> for NodeLeaf {
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Mirror of the serialized shape of NodeLeaf, minus token_ids.
        #[derive(Deserialize)]
        struct Raw {
            body: NodeBody,
            timestamp: DateTime<Utc>,
        }
        let raw = Raw::deserialize(deserializer)?;
        let token_ids = raw.body.compute_token_ids();
        Ok(NodeLeaf { body: raw.body, token_ids, timestamp: raw.timestamp })
    }
}
|
||
|
||
/// A node in the context AST.
/// Leaves hold typed content; branches hold a role plus children and render
/// the `<|im_start|>role … <|im_end|>` envelope around them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstNode {
    Leaf(NodeLeaf),
    Branch {
        role: Role,
        children: Vec<AstNode>,
        timestamp: DateTime<Utc>,
        /// Per-response memory attribution from full scoring matrix.
        /// Maps memory key → divergence score for this response.
        #[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")]
        memory_scores: std::collections::BTreeMap<String, f64>,
    },
}
|
||
|
||
/// The context window: four sections as Vec<AstNode>.
/// All mutation goes through ContextState methods to maintain the invariant
/// that token_ids on every leaf matches its rendered text.
/// Sections render in declaration order: system, identity, journal,
/// conversation (see `sections()` and `Ast::render`).
pub struct ContextState {
    system: Vec<AstNode>,
    identity: Vec<AstNode>,
    journal: Vec<AstNode>,
    conversation: Vec<AstNode>,
    // Optional append-only log; None on forked contexts (see Clone impl).
    pub conversation_log: Option<crate::mind::log::ConversationLog>,
}
|
||
|
||
impl Clone for ContextState {
|
||
fn clone(&self) -> Self {
|
||
Self {
|
||
system: self.system.clone(),
|
||
identity: self.identity.clone(),
|
||
journal: self.journal.clone(),
|
||
conversation: self.conversation.clone(),
|
||
conversation_log: None, // forked contexts don't log
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Identifies a section for mutation methods.
/// Variant order matches the field/render order in ContextState.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Section {
    System,
    Identity,
    Journal,
    Conversation,
}
|
||
|
||
/// Ephemeral handle for dispatching a tool call. Not persisted in the AST.
/// Produced by ResponseParser::feed_token and delivered over the channel
/// returned by ResponseParser::run.
#[derive(Debug, Clone)]
pub struct PendingToolCall {
    pub name: String,
    // JSON-encoded argument object.
    pub arguments: String,
    // Monotonic per-response id of the form "call_{n}".
    pub id: String,
}
|
||
|
||
/// Common view of renderable AST values (implemented by AstNode and
/// ContextState).
pub trait Ast {
    /// Full prompt text of this value.
    fn render(&self) -> String;
    /// Token IDs corresponding exactly to `render()`.
    fn token_ids(&self) -> Vec<u32>;
    /// Token count; cheaper than `token_ids().len()` where lengths are cached.
    fn tokens(&self) -> usize;
}
|
||
|
||
/// Incremental parser for a streamed assistant response. Splits the token
/// stream into content / <think> / <tool_call> segments and appends them as
/// children of the conversation branch at `branch_idx`.
pub struct ResponseParser {
    // Index of the assistant Branch in ContextState::conversation to append to.
    branch_idx: usize,
    // Counts emitted tool calls; used to mint "call_{n}" ids.
    call_counter: u32,
    // Unconsumed stream text (may end with a partial tag).
    buf: String,
    // Plain-content fragments accumulated until the next flush.
    content_parts: Vec<String>,
    // True while inside <think> … </think>; text accumulates in think_buf.
    in_think: bool,
    think_buf: String,
    // True while inside <tool_call> … </tool_call>; text accumulates in tool_call_buf.
    in_tool_call: bool,
    tool_call_buf: String,
}
|
||
|
||
impl Role {
|
||
pub fn as_str(&self) -> &'static str {
|
||
match self {
|
||
Self::System => "system",
|
||
Self::User => "user",
|
||
Self::Assistant => "assistant",
|
||
}
|
||
}
|
||
}
|
||
|
||
impl NodeBody {
    /// Render this leaf body to text for the prompt.
    /// Thinking and Log contribute nothing (see is_prompt_visible);
    /// ToolResult/Memory/Dmn are "self-wrapping" and emit their own
    /// im_start/im_end envelope per the module grammar.
    fn render_into(&self, out: &mut String) {
        match self {
            Self::Content(text) => out.push_str(text),
            Self::Thinking(_) => {},
            Self::Log(_) => {},
            Self::ToolCall { name, arguments } => {
                out.push_str("<tool_call>\n");
                out.push_str(&format_tool_call_xml(name, arguments));
                out.push_str("\n</tool_call>\n");
            }
            // Tool results are wrapped as a user-role message per the grammar.
            Self::ToolResult(text) => {
                out.push_str("<|im_start|>user\n<tool_response>\n");
                out.push_str(text);
                out.push_str("\n</tool_response><|im_end|>\n");
            }
            Self::Memory { text, .. } => {
                out.push_str("<|im_start|>memory\n");
                out.push_str(text);
                out.push_str("<|im_end|>\n");
            }
            Self::Dmn(text) => {
                out.push_str("<|im_start|>dmn\n");
                out.push_str(text);
                out.push_str("<|im_end|>\n");
            }
            // N × <|image_pad|> placeholder; N matches vLLM's count for this
            // image's grid so budget accounting stays accurate.
            Self::Image { token_count, .. } => {
                out.push_str("<|vision_start|>");
                for _ in 0..*token_count {
                    out.push_str("<|image_pad|>");
                }
                out.push_str("<|vision_end|>");
            }
        }
    }

    /// Render this body to a fresh String (convenience over render_into).
    fn render(&self) -> String {
        let mut s = String::new();
        self.render_into(&mut s);
        s
    }

    /// Whether this leaf contributes tokens to the prompt.
    fn is_prompt_visible(&self) -> bool {
        !matches!(self, Self::Thinking(_) | Self::Log(_))
    }

    /// Hand-assemble token IDs for body types where running the tokenizer
    /// on the rendered text would be needlessly expensive (Image). Falls
    /// back to encoding the rendered text for everything else.
    fn compute_token_ids(&self) -> Vec<u32> {
        if !self.is_prompt_visible() {
            return Vec::new();
        }
        match self {
            // vision_start + N pads + vision_end, built directly from the
            // special-token IDs instead of tokenizing a huge placeholder string.
            Self::Image { token_count, .. } => {
                let mut ids = Vec::with_capacity(*token_count as usize + 2);
                ids.push(tokenizer::VISION_START);
                ids.extend(std::iter::repeat(tokenizer::IMAGE_PAD)
                    .take(*token_count as usize));
                ids.push(tokenizer::VISION_END);
                ids
            }
            _ => tokenizer::encode(&self.render()),
        }
    }

    /// The text content of this leaf (for display, not rendering).
    /// Variants without natural text fall back to a representative field
    /// (tool name, memory text, image mime type).
    pub fn text(&self) -> &str {
        match self {
            Self::Content(t) | Self::Thinking(t) | Self::Log(t)
            | Self::ToolResult(t) | Self::Dmn(t) => t,
            Self::ToolCall { name, .. } => name,
            Self::Memory { text, .. } => text,
            Self::Image { mime, .. } => mime,
        }
    }
}
|
||
|
||
impl NodeLeaf {
    /// Build a leaf, establishing the invariant that token_ids matches
    /// the body's rendered text. Timestamp defaults to now.
    fn new(body: NodeBody) -> Self {
        let token_ids = body.compute_token_ids();
        Self { body, token_ids, timestamp: Utc::now() }
    }

    /// Builder: override the creation timestamp (e.g. when replaying a log).
    pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
        self.timestamp = ts;
        self
    }

    pub fn body(&self) -> &NodeBody { &self.body }
    pub fn token_ids(&self) -> &[u32] { &self.token_ids }
    // Cached count — no tokenizer call.
    pub fn tokens(&self) -> usize { self.token_ids.len() }
    pub fn timestamp(&self) -> DateTime<Utc> { self.timestamp }
}
|
||
|
||
impl AstNode {
    // -- Leaf constructors ----------------------------------------------------
    // Each wraps NodeLeaf::new, which tokenizes the body immediately so the
    // leaf invariant (token_ids matches rendered text) holds from birth.

    pub fn content(text: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::Content(text.into())))
    }

    pub fn thinking(text: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::Thinking(text.into())))
    }

    /// `arguments` is the JSON-encoded argument object.
    pub fn tool_call(name: impl Into<String>, arguments: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::ToolCall {
            name: name.into(),
            arguments: arguments.into(),
        }))
    }

    pub fn tool_result(text: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::ToolResult(text.into())))
    }

    /// Memory leaf starts unscored; scores are attached later via
    /// ContextState::set_score.
    pub fn memory(key: impl Into<String>, text: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::Memory {
            key: key.into(),
            text: text.into(),
            score: None,
        }))
    }

    pub fn dmn(text: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::Dmn(text.into())))
    }

    pub fn log(text: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::Log(text.into())))
    }

    /// Build an Image leaf. `token_count` is computed from the image
    /// dimensions using Qwen3-VL's resizing rules.
    pub fn image(
        bytes: Vec<u8>,
        mime: impl Into<String>,
        orig_height: u32,
        orig_width: u32,
    ) -> Self {
        let token_count = qwen3_image_token_count(orig_height, orig_width);
        Self::Leaf(NodeLeaf::new(NodeBody::Image {
            bytes,
            mime: mime.into(),
            orig_height,
            orig_width,
            token_count,
        }))
    }

    // -- Branch constructors --------------------------------------------------

    /// Generic message branch with an arbitrary child list.
    pub fn branch(role: Role, children: Vec<AstNode>) -> Self {
        Self::Branch { role, children, timestamp: Utc::now(), memory_scores: Default::default() }
    }

    /// Single-content system message.
    pub fn system_msg(text: impl Into<String>) -> Self {
        Self::Branch {
            role: Role::System,
            children: vec![Self::content(text)],
            timestamp: Utc::now(),
            memory_scores: Default::default(),
        }
    }

    /// Single-content user message.
    pub fn user_msg(text: impl Into<String>) -> Self {
        Self::Branch {
            role: Role::User,
            children: vec![Self::content(text)],
            timestamp: Utc::now(),
            memory_scores: Default::default(),
        }
    }

    // -- Builder --------------------------------------------------------------

    /// Recompute token_ids on every leaf in this subtree (e.g. after a
    /// tokenizer change); all other fields are preserved.
    pub fn retokenize(self) -> Self {
        match self {
            Self::Leaf(leaf) => {
                let token_ids = leaf.body.compute_token_ids();
                Self::Leaf(NodeLeaf { token_ids, ..leaf })
            }
            Self::Branch { role, children, timestamp, memory_scores } => Self::Branch {
                role,
                children: children.into_iter().map(|c| c.retokenize()).collect(),
                timestamp,
                memory_scores,
            },
        }
    }

    /// Builder: override the timestamp on either node kind.
    pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
        match &mut self {
            Self::Leaf(leaf) => leaf.timestamp = ts,
            Self::Branch { timestamp, .. } => *timestamp = ts,
        }
        self
    }

    /// Children of a branch; empty slice for a leaf.
    pub fn children(&self) -> &[AstNode] {
        match self {
            Self::Branch { children, .. } => children,
            Self::Leaf(_) => &[],
        }
    }

    /// The leaf payload, if this node is a leaf.
    pub fn leaf(&self) -> Option<&NodeLeaf> {
        match self {
            Self::Leaf(l) => Some(l),
            _ => None,
        }
    }

    /// Short label for the UI.
    pub fn label(&self) -> String {
        let app = crate::config::app();
        match self {
            // Branches: role name plus a preview of the first leaf child.
            Self::Branch { role, children, .. } => {
                let preview = children.first()
                    .and_then(|c| c.leaf())
                    .map(|l| truncate_preview(l.body.text(), 60))
                    .unwrap_or_default();
                match role {
                    Role::System => "system".into(),
                    Role::User => format!("{}: {}", app.user_name, preview),
                    Role::Assistant => format!("{}: {}", app.assistant_name, preview),
                }
            }
            Self::Leaf(leaf) => match &leaf.body {
                NodeBody::Content(t) => truncate_preview(t, 60),
                NodeBody::Thinking(t) => format!("thinking: {}", truncate_preview(t, 60)),
                NodeBody::ToolCall { name, arguments } => format!("tool: {}({})", name, truncate_preview(arguments, 80)),
                NodeBody::ToolResult(_) => "tool_result".into(),
                NodeBody::Memory { key, score, .. } => match score {
                    Some(s) => format!("mem: {} score:{:.1}", key, s),
                    None => format!("mem: {}", key),
                },
                NodeBody::Dmn(_) => "dmn".into(),
                NodeBody::Image { orig_height, orig_width, token_count, .. } =>
                    format!("image: {}x{} ({} tokens)", orig_width, orig_height, token_count),
                NodeBody::Log(t) => format!("log: {}", truncate_preview(t, 60)),
            },
        }
    }
}
|
||
|
||
impl AstNode {
    /// Append this subtree's prompt text: leaves defer to their body,
    /// branches wrap children in the im_start/im_end envelope.
    fn render_into(&self, out: &mut String) {
        match self {
            Self::Leaf(leaf) => leaf.body.render_into(out),
            Self::Branch { role, children, .. } => {
                out.push_str(&format!("<|im_start|>{}\n", role.as_str()));
                for child in children {
                    child.render_into(out);
                }
                out.push_str("<|im_end|>\n");
            }
        }
    }

    /// Append this subtree's token IDs. Must stay in lock-step with
    /// render_into: special tokens for the envelope, encoded text for the
    /// role header and trailing newline, cached ids for leaves.
    fn token_ids_into(&self, out: &mut Vec<u32>) {
        match self {
            Self::Leaf(leaf) => out.extend_from_slice(&leaf.token_ids),
            Self::Branch { role, children, .. } => {
                out.push(tokenizer::IM_START);
                out.extend(tokenizer::encode(&format!("{}\n", role.as_str())));
                for child in children {
                    child.token_ids_into(out);
                }
                out.push(tokenizer::IM_END);
                out.extend(tokenizer::encode("\n"));
            }
        }
    }
}
|
||
|
||
impl Ast for AstNode {
    fn render(&self) -> String {
        let mut s = String::new();
        self.render_into(&mut s);
        s
    }

    fn token_ids(&self) -> Vec<u32> {
        let mut ids = Vec::new();
        self.token_ids_into(&mut ids);
        ids
    }

    /// Count without tokenizing: leaves use their cached length; branches
    /// add 1 (im_start) + role header + children + 1 (im_end) + trailing
    /// newline, mirroring token_ids_into exactly.
    fn tokens(&self) -> usize {
        match self {
            Self::Leaf(leaf) => leaf.tokens(),
            Self::Branch { role, children, .. } => {
                1 + role_header_tokens(*role)
                    + children.iter().map(|c| c.tokens()).sum::<usize>()
                    + 1 + newline_tokens()
            }
        }
    }
}
|
||
|
||
/// First `max` characters of `s`, newlines flattened to spaces, with "..."
/// appended when `s` was actually truncated.
///
/// Truncation is decided in characters (not bytes): the original compared
/// byte length `s.len()` against the char budget, so multibyte strings that
/// fit in `max` chars could still get a spurious "..." suffix.
fn truncate_preview(s: &str, max: usize) -> String {
    let preview: String = s.chars().take(max).collect();
    let preview = preview.replace('\n', " ");
    // nth(max) is Some iff s has more than `max` chars; avoids a full count.
    if s.chars().nth(max).is_some() { format!("{}...", preview) } else { preview }
}
|
||
|
||
fn format_tool_call_xml(name: &str, args_json: &str) -> String {
|
||
let args: serde_json::Value = serde_json::from_str(args_json)
|
||
.unwrap_or(serde_json::Value::Object(Default::default()));
|
||
let mut xml = format!("<function={}>\n", name);
|
||
if let Some(obj) = args.as_object() {
|
||
for (key, value) in obj {
|
||
let val_str = match value {
|
||
serde_json::Value::String(s) => s.clone(),
|
||
other => other.to_string(),
|
||
};
|
||
xml.push_str(&format!("<parameter={}>\n{}\n</parameter>\n", key, val_str));
|
||
}
|
||
}
|
||
xml.push_str("</function>");
|
||
xml
|
||
}
|
||
|
||
/// Search for a sequence of literal parts separated by optional ASCII whitespace.
|
||
/// Returns (start, end) byte positions of the overall match.
|
||
///
|
||
/// Handles the case where streaming tokenization inserts whitespace inside
|
||
/// XML tag structure, e.g. `< function = bash >` instead of `<function=bash>`.
|
||
/// Search for a sequence of literal parts separated by optional ASCII
/// whitespace. Returns (start, end) byte positions of the overall match.
///
/// Handles the case where streaming tokenization inserts whitespace inside
/// XML tag structure, e.g. `< function = bash >` instead of `<function=bash>`.
fn find_ws_seq(s: &str, parts: &[&str]) -> Option<(usize, usize)> {
    let data = s.as_bytes();
    let mut from = 0;
    loop {
        // Anchor on the next occurrence of the first part.
        let begin = match s[from..].find(parts[0]) {
            Some(off) => from + off,
            None => return None,
        };
        let mut cursor = begin + parts[0].len();
        let mut matched = true;
        for part in parts.iter().skip(1) {
            // Skip any ASCII whitespace the tokenizer may have injected.
            while cursor < data.len() && data[cursor].is_ascii_whitespace() {
                cursor += 1;
            }
            if s[cursor..].starts_with(*part) {
                cursor += part.len();
            } else {
                matched = false;
                break;
            }
        }
        if matched {
            return Some((begin, cursor));
        }
        // This anchor didn't pan out — retry just past it.
        from = begin + 1;
    }
}
|
||
|
||
/// Parse a Qwen-style XML tag: `<tag=name>body</tag>`.
|
||
/// Tolerates whitespace inside tag delimiters (streaming artifact).
|
||
/// Body content is returned verbatim except for a single leading/trailing
|
||
/// newline (XML formatting convention).
|
||
/// Parse a Qwen-style XML tag: `<tag=name>body</tag>`.
/// Tolerates whitespace inside tag delimiters (streaming artifact).
/// Body content is returned verbatim except for a single leading/trailing
/// newline (XML formatting convention).
/// Returns (name, body, rest-after-close-tag), or None if the open or
/// close tag is missing.
fn parse_qwen_tag<'a>(s: &'a str, tag: &str) -> Option<(&'a str, &'a str, &'a str)> {
    // Open tag: tolerate whitespace from streaming tokenization
    let (_, after_eq) = find_ws_seq(s, &["<", tag, "="])?;
    // Name runs from '=' to the next '>'; trim tolerates "< function = bash >".
    let gt_offset = s[after_eq..].find('>')?;
    let name = s[after_eq..after_eq + gt_offset].trim();
    let body_start = after_eq + gt_offset + 1;

    // Close tag: exact match — model doesn't insert whitespace in close tags
    let close = format!("</{}>", tag);
    let close_offset = s[body_start..].find(&close)?;
    let body = &s[body_start..body_start + close_offset];
    // Strip the single leading/trailing newline from XML formatting,
    // but preserve all other whitespace (indentation matters for code).
    let body = body.strip_prefix('\n').unwrap_or(body);
    let body = body.strip_suffix('\n').unwrap_or(body);
    let rest = &s[body_start + close_offset + close.len()..];

    Some((name, body, rest))
}
|
||
|
||
fn parse_tool_call_body(body: &str) -> Option<(String, String)> {
|
||
let body = body.trim();
|
||
parse_xml_tool_call(body)
|
||
.or_else(|| parse_json_tool_call(body))
|
||
}
|
||
|
||
fn parse_xml_tool_call(body: &str) -> Option<(String, String)> {
|
||
let (func_name, func_body, _) = parse_qwen_tag(body, "function")?;
|
||
let mut args = serde_json::Map::new();
|
||
let mut rest = func_body;
|
||
while let Some((key, val, remainder)) = parse_qwen_tag(rest, "parameter") {
|
||
let value = serde_json::from_str(val)
|
||
.unwrap_or(serde_json::Value::String(val.to_string()));
|
||
args.insert(key.to_string(), value);
|
||
rest = remainder;
|
||
}
|
||
Some((func_name.to_string(), serde_json::to_string(&args).unwrap_or_default()))
|
||
}
|
||
|
||
fn parse_json_tool_call(body: &str) -> Option<(String, String)> {
|
||
let v: serde_json::Value = serde_json::from_str(body).ok()?;
|
||
let name = v["name"].as_str()?;
|
||
let arguments = &v["arguments"];
|
||
Some((name.to_string(), serde_json::to_string(arguments).unwrap_or_default()))
|
||
}
|
||
|
||
/// Search `buf` for `close_tag`. If found, append everything before it to
|
||
/// `accum`, advance `buf` past the tag, and return the accumulated content.
|
||
/// If not found, drain the safe prefix (preserving any partial tag match at
|
||
/// the end of buf) into `accum`.
|
||
fn scan_close_tag(buf: &mut String, close_tag: &str, accum: &mut String) -> Option<String> {
|
||
if let Some(pos) = buf.find(close_tag) {
|
||
accum.push_str(&buf[..pos]);
|
||
*buf = buf[pos + close_tag.len()..].to_string();
|
||
Some(std::mem::take(accum))
|
||
} else {
|
||
let drained = drain_safe(buf, close_tag.len());
|
||
if !drained.is_empty() {
|
||
accum.push_str(&drained);
|
||
}
|
||
None
|
||
}
|
||
}
|
||
|
||
/// Remove everything from `buf` except the last `tag_len` bytes, which might
|
||
/// be a partial tag. Returns the removed prefix.
|
||
fn drain_safe(buf: &mut String, tag_len: usize) -> String {
|
||
let safe = buf.len().saturating_sub(tag_len);
|
||
if safe > 0 {
|
||
let safe = buf.floor_char_boundary(safe);
|
||
let drained = buf[..safe].to_string();
|
||
*buf = buf[safe..].to_string();
|
||
drained
|
||
} else {
|
||
String::new()
|
||
}
|
||
}
|
||
|
||
impl ResponseParser {
    /// Fresh parser that will append parsed children to the conversation
    /// branch at `branch_idx`.
    pub fn new(branch_idx: usize) -> Self {
        Self {
            branch_idx,
            call_counter: 0,
            buf: String::new(),
            content_parts: Vec::new(),
            in_think: false,
            think_buf: String::new(),
            in_tool_call: false,
            tool_call_buf: String::new(),
        }
    }

    /// Consume a token stream, parse into the AST, yield tool calls.
    /// Spawns a background task. Returns a tool call receiver and a
    /// join handle that resolves to Ok(()) or the stream error.
    ///
    /// The task also best-effort appends a debug trace to
    /// /tmp/poc-{provenance}.log; all log I/O errors are ignored.
    /// NOTE(review): agent.context is a tokio Mutex re-locked per token;
    /// other holders of that lock can interleave between tokens.
    pub fn run(
        self,
        mut stream: tokio::sync::mpsc::UnboundedReceiver<super::api::StreamToken>,
        agent: std::sync::Arc<super::Agent>,
    ) -> (
        tokio::sync::mpsc::UnboundedReceiver<PendingToolCall>,
        tokio::task::JoinHandle<anyhow::Result<()>>,
    ) {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let handle = tokio::spawn(async move {
            let mut parser = self;
            let agent_name = agent.state.lock().await.provenance.clone();
            let log_path = format!("/tmp/poc-{}.log", agent_name);
            // Best-effort debug log; None if the file can't be opened.
            let mut log_file = std::fs::OpenOptions::new()
                .create(true).append(true).open(&log_path).ok();
            let mut full_text = String::new();
            while let Some(event) = stream.recv().await {
                match event {
                    // One model token: decode and feed to the incremental parser.
                    super::api::StreamToken::Token(id) => {
                        let text = super::tokenizer::decode(&[id]);
                        full_text.push_str(&text);
                        let mut ctx = agent.context.lock().await;
                        let calls = parser.feed_token(&text, &mut ctx);
                        if !calls.is_empty() {
                            if let Some(ref mut f) = log_file {
                                use std::io::Write;
                                for c in &calls {
                                    // Cap logged args at ~200 bytes, on a char boundary.
                                    let end = c.arguments.floor_char_boundary(c.arguments.len().min(200));
                                    let _ = writeln!(f, "tool_call: {} args={}", c.name, &c.arguments[..end]);
                                }
                            }
                        }
                        for call in calls {
                            let _ = tx.send(call);
                        }
                    }
                    // End of stream: log a summary, record usage, flush the parser.
                    super::api::StreamToken::Done { usage } => {
                        if let Some(ref mut f) = log_file {
                            use std::io::Write;
                            let ctx = agent.context.lock().await;
                            let children = ctx.conversation().get(parser.branch_idx)
                                .map(|n| n.children()).unwrap_or(&[]);
                            let n_think = children.iter().filter(|c| matches!(c.leaf().map(|l| l.body()), Some(NodeBody::Thinking(_)))).count();
                            let n_content = children.iter().filter(|c| matches!(c.leaf().map(|l| l.body()), Some(NodeBody::Content(_)))).count();
                            let n_tool = children.iter().filter(|c| matches!(c.leaf().map(|l| l.body()), Some(NodeBody::ToolCall { .. }))).count();
                            let _ = writeln!(f, "done: {} chars, {} content + {} think + {} tool_call, ctx: {} tokens",
                                full_text.len(), n_content, n_think, n_tool, ctx.tokens());
                            drop(ctx);
                            // Nothing parsed out of a non-empty response — dump it for debugging.
                            if full_text.len() > 0 && n_content == 0 && n_tool == 0 {
                                let end = full_text.floor_char_boundary(full_text.len().min(2000));
                                let _ = writeln!(f, "  unparsed text: {}", &full_text[..end]);
                            }
                        }
                        if let Some(u) = usage {
                            agent.state.lock().await.last_prompt_tokens = u.prompt_tokens;
                        }
                        let mut ctx = agent.context.lock().await;
                        parser.finish(&mut ctx);
                        return Ok(());
                    }
                    super::api::StreamToken::Error(e) => {
                        return Err(anyhow::anyhow!("{}", e));
                    }
                }
            }
            // Stream closed without a Done event; nothing left to flush here.
            Ok(())
        });
        (rx, handle)
    }

    /// Feed one decoded token's text. Advances the tag state machine,
    /// appends completed thinking/content/tool_call leaves to the AST, and
    /// returns any tool calls completed by this token.
    pub fn feed_token(&mut self, text: &str, ctx: &mut ContextState) -> Vec<PendingToolCall> {
        const THINK_OPEN: &str = "<think>";
        const THINK_CLOSE: &str = "</think>";
        const TOOL_CALL_OPEN: &str = "<tool_call>";
        const TOOL_CALL_CLOSE: &str = "</tool_call>";
        const OPEN_TAGS: &[&str] = &[THINK_OPEN, TOOL_CALL_OPEN];

        let mut pending = Vec::new();
        self.buf.push_str(text);

        loop {
            // Inside <think>: accumulate until the close tag arrives.
            if self.in_think {
                if let Some(content) = scan_close_tag(&mut self.buf, THINK_CLOSE, &mut self.think_buf) {
                    self.in_think = false;
                    let text = content.trim().to_string();
                    if !text.is_empty() {
                        self.push_child(ctx, AstNode::thinking(text));
                    }
                    continue;
                }
                break;
            }

            // Inside <tool_call>: accumulate, then parse and emit on close.
            if self.in_tool_call {
                if let Some(content) = scan_close_tag(&mut self.buf, TOOL_CALL_CLOSE, &mut self.tool_call_buf) {
                    self.in_tool_call = false;
                    if let Some((name, args)) = parse_tool_call_body(&content) {
                        self.flush_content(ctx);
                        self.push_child(ctx, AstNode::tool_call(&name, &args));
                        self.call_counter += 1;
                        pending.push(PendingToolCall {
                            name,
                            arguments: args,
                            id: format!("call_{}", self.call_counter),
                        });
                    }
                    continue;
                }
                break;
            }

            // Not inside a tag — find the earliest opening tag
            let next = OPEN_TAGS.iter()
                .filter_map(|tag| self.buf.find(tag).map(|pos| (pos, *tag)))
                .min_by_key(|(pos, _)| *pos);

            match next {
                Some((pos, tag)) => {
                    // Text before the tag is plain content.
                    if pos > 0 {
                        self.content_parts.push(self.buf[..pos].to_string());
                    }
                    self.buf = self.buf[pos + tag.len()..].to_string();
                    self.flush_content(ctx);
                    match tag {
                        THINK_OPEN => self.in_think = true,
                        TOOL_CALL_OPEN => self.in_tool_call = true,
                        _ => unreachable!(),
                    }
                    continue;
                }
                None => {
                    // Keep a tail that might be a partial opening tag
                    let max_tag = OPEN_TAGS.iter().map(|t| t.len()).max().unwrap();
                    let drained = drain_safe(&mut self.buf, max_tag);
                    if !drained.is_empty() {
                        self.content_parts.push(drained);
                    }
                    break;
                }
            }
        }

        pending
    }

    /// Append a parsed child to the conversation branch this parser targets.
    fn push_child(&self, ctx: &mut ContextState, child: AstNode) {
        ctx.push_child(Section::Conversation, self.branch_idx, child);
    }

    /// Join accumulated content fragments into one Content leaf (trimmed;
    /// dropped entirely if whitespace-only).
    fn flush_content(&mut self, ctx: &mut ContextState) {
        if !self.content_parts.is_empty() {
            let text: String = self.content_parts.drain(..).collect();
            let text = text.trim().to_string();
            if !text.is_empty() {
                self.push_child(ctx, AstNode::content(text));
            }
        }
    }

    /// End of stream: treat any leftover buffered text as content.
    /// Note: an unterminated <think>/<tool_call> body left in think_buf /
    /// tool_call_buf is silently discarded here.
    pub fn finish(mut self, ctx: &mut ContextState) {
        if !self.buf.is_empty() {
            self.content_parts.push(std::mem::take(&mut self.buf));
        }
        self.flush_content(ctx);
    }
}
|
||
|
||
impl ContextState {
    /// Empty context: all four sections empty, no conversation log.
    pub fn new() -> Self {
        Self {
            system: Vec::new(),
            identity: Vec::new(),
            journal: Vec::new(),
            conversation: Vec::new(),
            conversation_log: None,
        }
    }

    // -- Read access ----------------------------------------------------------

    pub fn system(&self) -> &[AstNode] { &self.system }
    pub fn identity(&self) -> &[AstNode] { &self.identity }
    pub fn journal(&self) -> &[AstNode] { &self.journal }
    pub fn conversation(&self) -> &[AstNode] { &self.conversation }
    // Direct mutable access bypasses the token-id invariant helpers;
    // callers must not leave leaves with stale token_ids.
    pub fn conversation_mut(&mut self) -> &mut Vec<AstNode> { &mut self.conversation }

    /// All four sections in prompt order (system, identity, journal,
    /// conversation) — the traversal order used by render/token_ids/tokens.
    pub fn sections(&self) -> [&Vec<AstNode>; 4] {
        [&self.system, &self.identity, &self.journal, &self.conversation]
    }
}
|
||
|
||
// The full prompt is a depth-first traversal of the sections in order.
impl Ast for ContextState {
    fn render(&self) -> String {
        let mut s = String::new();
        for section in self.sections() {
            for node in section {
                s.push_str(&node.render());
            }
        }
        s
    }

    fn token_ids(&self) -> Vec<u32> {
        let mut ids = Vec::new();
        for section in self.sections() {
            for node in section {
                ids.extend(node.token_ids());
            }
        }
        ids
    }

    // Sum of per-node counts (cached on leaves — no tokenizer calls).
    fn tokens(&self) -> usize {
        self.sections().iter()
            .flat_map(|s| s.iter())
            .map(|n| n.tokens())
            .sum()
    }
}
|
||
|
||
impl ContextState {
|
||
    /// Mutable handle on the Vec backing `section`.
    fn section_mut(&mut self, section: Section) -> &mut Vec<AstNode> {
        match section {
            Section::System => &mut self.system,
            Section::Identity => &mut self.identity,
            Section::Journal => &mut self.journal,
            Section::Conversation => &mut self.conversation,
        }
    }
|
||
|
||
    /// Push and log to conversation log.
    /// Log failures are reported via dbglog! but never block the push.
    pub fn push_log(&mut self, section: Section, node: AstNode) {
        if let Some(ref log) = self.conversation_log {
            if let Err(e) = log.append_node(&node) {
                dbglog!("warning: log: {:#}", e);
            }
        }
        self.section_mut(section).push(node);
    }
|
||
|
||
    /// Push without logging (for nodes that should not hit the
    /// conversation log, e.g. replayed or transient entries).
    pub fn push_no_log(&mut self, section: Section, node: AstNode) {
        self.section_mut(section).push(node);
    }
|
||
|
||
    /// Replace the body of a leaf at `index` in `section`.
    /// Re-tokenizes to maintain the invariant.
    ///
    /// Panics if `index` is out of bounds or the node is a branch —
    /// callers are expected to address an existing leaf.
    pub fn set_message(&mut self, section: Section, index: usize, body: NodeBody) {
        let nodes = self.section_mut(section);
        let node = &mut nodes[index];
        match node {
            AstNode::Leaf(leaf) => {
                // Compute first, then swap both fields together so the
                // leaf never holds a body with stale token_ids.
                let token_ids = body.compute_token_ids();
                leaf.body = body;
                leaf.token_ids = token_ids;
            }
            AstNode::Branch { .. } => panic!("set_message on branch node"),
        }
    }
|
||
|
||
    /// Set the memory score on a Memory leaf at `index` in `section`.
    /// Score does not affect rendering, so no re-tokenization is needed.
    ///
    /// Panics if the node at `index` is not a Memory leaf (or is out of
    /// bounds).
    pub fn set_score(&mut self, section: Section, index: usize, score: Option<f64>) {
        let node = &mut self.section_mut(section)[index];
        match node {
            AstNode::Leaf(leaf) => match &mut leaf.body {
                NodeBody::Memory { score: s, .. } => *s = score,
                _ => panic!("set_score on non-memory node"),
            },
            _ => panic!("set_score on branch node"),
        }
    }
|
||
|
||
    /// Remove and return the node at `index` in `section`.
    /// Panics if `index` is out of bounds (Vec::remove semantics).
    pub fn del(&mut self, section: Section, index: usize) -> AstNode {
        self.section_mut(section).remove(index)
    }
|
||
|
||
    /// Remove all nodes from `section`.
    pub fn clear(&mut self, section: Section) {
        self.section_mut(section).clear();
    }
|
||
|
||
/// Dedup and trim conversation entries to fit within the context budget.
|
||
///
|
||
/// Phase 1: Drop duplicate memories (keep last) and DMN entries.
|
||
/// Phase 2: While over budget, drop lowest-scored memory (if memories
|
||
/// are > 50% of conversation tokens) or oldest conversation entry.
|
||
/// Phase 3: Snap to user message boundary at start.
|
||
pub fn trim_conversation(&mut self) {
|
||
let max_tokens = context_budget_tokens();
|
||
let fixed = self.system.iter().map(|n| n.tokens()).sum::<usize>()
|
||
+ self.identity.iter().map(|n| n.tokens()).sum::<usize>()
|
||
+ self.journal.iter().map(|n| n.tokens()).sum::<usize>();
|
||
|
||
// Phase 1: dedup memories by key (keep last), drop DMN
|
||
let mut seen_keys: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
|
||
let mut drop = std::collections::HashSet::new();
|
||
|
||
for (i, node) in self.conversation.iter().enumerate() {
|
||
if let AstNode::Leaf(leaf) = node {
|
||
match leaf.body() {
|
||
NodeBody::Dmn(_) => { drop.insert(i); }
|
||
NodeBody::Memory { key, .. } => {
|
||
if let Some(prev) = seen_keys.insert(key.clone(), i) {
|
||
drop.insert(prev);
|
||
}
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
}
|
||
|
||
if !drop.is_empty() {
|
||
let mut i = 0;
|
||
self.conversation.retain(|_| { let keep = !drop.contains(&i); i += 1; keep });
|
||
}
|
||
|
||
// Phase 2: while over budget, evict
|
||
loop {
|
||
let total: usize = self.conversation.iter().map(|n| n.tokens()).sum();
|
||
if fixed + total <= max_tokens { break; }
|
||
let mt: usize = self.conversation.iter()
|
||
.filter(|n| matches!(n, AstNode::Leaf(l) if matches!(l.body(), NodeBody::Memory { .. })))
|
||
.map(|n| n.tokens()).sum();
|
||
let ct = total - mt;
|
||
|
||
if mt > ct {
|
||
// Memories > 50% — drop lowest-scored
|
||
if let Some(i) = self.lowest_scored_memory() {
|
||
self.conversation.remove(i);
|
||
continue;
|
||
}
|
||
}
|
||
// Drop oldest non-memory entry
|
||
if let Some(i) = self.conversation.iter().position(|n|
|
||
!matches!(n, AstNode::Leaf(l) if matches!(l.body(), NodeBody::Memory { .. })))
|
||
{
|
||
self.conversation.remove(i);
|
||
} else {
|
||
break;
|
||
}
|
||
}
|
||
|
||
// Phase 3: snap to user message boundary
|
||
while let Some(first) = self.conversation.first() {
|
||
if matches!(first, AstNode::Branch { role: Role::User, .. }) { break; }
|
||
self.conversation.remove(0);
|
||
}
|
||
}
|
||
|
||
fn lowest_scored_memory(&self) -> Option<usize> {
|
||
self.conversation.iter().enumerate()
|
||
.filter_map(|(i, n)| {
|
||
if let AstNode::Leaf(l) = n {
|
||
if let NodeBody::Memory { score: Some(s), .. } = l.body() {
|
||
return Some((i, *s));
|
||
}
|
||
}
|
||
None
|
||
})
|
||
.min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||
.map(|(i, _)| i)
|
||
}
|
||
|
||
/// Push a child node into a branch at `index` in `section`.
|
||
pub fn push_child(&mut self, section: Section, index: usize, child: AstNode) {
|
||
let node = &mut self.section_mut(section)[index];
|
||
match node {
|
||
AstNode::Branch { children, .. } => children.push(child),
|
||
AstNode::Leaf(_) => panic!("push_child on leaf node"),
|
||
}
|
||
}
|
||
|
||
/// Number of nodes in a section.
|
||
pub fn len(&self, section: Section) -> usize {
|
||
match section {
|
||
Section::System => self.system.len(),
|
||
Section::Identity => self.identity.len(),
|
||
Section::Journal => self.journal.len(),
|
||
Section::Conversation => self.conversation.len(),
|
||
}
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Qwen3-VL image token count
|
||
//
|
||
// Port of Qwen2VLImageProcessor.smart_resize + image_token_count. We need the
|
||
// exact same answer that vLLM's Qwen3VL processor will produce, because the
|
||
// token stream in our context must match what vLLM expands `<|image_pad|>`
|
||
// to at request time. Constants come from Qwen3.5-27B's preprocessor_config.
|
||
// ---------------------------------------------------------------------------
|
||
|
||
// Values from Qwen3.5-27B's preprocessor_config (see block comment above).
// Side length of one vision patch, in pixels.
const QWEN3_PATCH_SIZE: u32 = 16;
// merge_size^2 patches collapse into one <|image_pad|> token (see
// qwen3_image_token_count below).
const QWEN3_MERGE_SIZE: u32 = 2;
// Resized image area is clamped to [min_pixels, max_pixels].
const QWEN3_MIN_PIXELS: u64 = 65_536;
const QWEN3_MAX_PIXELS: u64 = 16_777_216;
|
||
|
||
/// Resize `(h, w)` to multiples of `factor` whose pixel count lies within
/// `[min_pixels, max_pixels]`, mirroring Qwen2VLImageProcessor.smart_resize.
/// The float expressions are kept operation-for-operation identical to the
/// Python reference so the result matches vLLM's exactly.
///
/// # Panics
/// Panics when the aspect ratio exceeds 200:1.
fn smart_resize(h: u32, w: u32, factor: u32, min_pixels: u64, max_pixels: u64) -> (u32, u32) {
    let long_side = h.max(w) as f64;
    let short_side = h.min(w) as f64;
    assert!(long_side / short_side <= 200.0, "aspect ratio too extreme: {}x{}", h, w);

    let (fh, fw, ff) = (h as f64, w as f64, factor as f64);

    // First pass: round each side to the nearest multiple of `factor`.
    let h_bar = ((fh / ff).round() as u32) * factor;
    let w_bar = ((fw / ff).round() as u32) * factor;
    let total = u64::from(h_bar) * u64::from(w_bar);

    if total > max_pixels {
        // Too big: scale down by sqrt of the overshoot, flooring so the
        // result stays under max; clamp to at least one factor per side.
        let beta = ((fh * fw) / max_pixels as f64).sqrt();
        let hf = ((fh / beta / ff).floor() as u32) * factor;
        let wf = ((fw / beta / ff).floor() as u32) * factor;
        (hf.max(factor), wf.max(factor))
    } else if total < min_pixels {
        // Too small: scale up by sqrt of the shortfall, ceiling so the
        // result clears the minimum.
        let beta = (min_pixels as f64 / (fh * fw)).sqrt();
        let hc = ((fh * beta / ff).ceil() as u32) * factor;
        let wc = ((fw * beta / ff).ceil() as u32) * factor;
        (hc, wc)
    } else {
        (h_bar, w_bar)
    }
}
|
||
|
||
/// Compute how many `<|image_pad|>` tokens vLLM will emit for an image of
|
||
/// the given dimensions. Matches Qwen3VL's feature-size calculation exactly:
|
||
/// (grid_h * grid_w) / merge_size^2
|
||
/// where (grid_h, grid_w) = resized dims / patch_size.
|
||
fn qwen3_image_token_count(orig_h: u32, orig_w: u32) -> u32 {
|
||
let factor = QWEN3_PATCH_SIZE * QWEN3_MERGE_SIZE;
|
||
let (rh, rw) = smart_resize(orig_h, orig_w, factor, QWEN3_MIN_PIXELS, QWEN3_MAX_PIXELS);
|
||
(rh / QWEN3_PATCH_SIZE) * (rw / QWEN3_PATCH_SIZE) / (QWEN3_MERGE_SIZE * QWEN3_MERGE_SIZE)
|
||
}
|
||
|
||
pub fn context_window() -> usize {
|
||
let app = crate::config::app();
|
||
app.backends.get(&app.default_backend)
|
||
.and_then(|b| b.context_window)
|
||
.unwrap_or(128_000)
|
||
}
|
||
|
||
pub fn context_budget_tokens() -> usize {
|
||
context_window() * 80 / 100
|
||
}
|
||
|
||
pub fn is_context_overflow(err: &anyhow::Error) -> bool {
|
||
let msg = err.to_string().to_lowercase();
|
||
msg.contains("context length")
|
||
|| msg.contains("token limit")
|
||
|| msg.contains("too many tokens")
|
||
|| msg.contains("maximum context")
|
||
|| msg.contains("prompt is too long")
|
||
|| msg.contains("request too large")
|
||
|| msg.contains("input validation error")
|
||
|| msg.contains("content length limit")
|
||
|| (msg.contains("400") && msg.contains("tokens"))
|
||
}
|
||
|
||
pub fn is_stream_error(err: &anyhow::Error) -> bool {
|
||
err.to_string().contains("model stream error")
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    // -- Helpers for inspecting parse results ----------------------------------

    // Leaf bodies of `nodes`, in order, skipping branch nodes.
    fn bodies(nodes: &[AstNode]) -> Vec<&NodeBody> {
        nodes.iter().filter_map(|c| c.leaf()).map(|l| l.body()).collect()
    }

    // Assert `body` is Content with exactly `expected` text.
    fn assert_content(body: &NodeBody, expected: &str) {
        match body {
            NodeBody::Content(t) => assert_eq!(t, expected),
            other => panic!("expected Content, got {:?}", other),
        }
    }

    // Assert `body` is Thinking with exactly `expected` text.
    fn assert_thinking(body: &NodeBody, expected: &str) {
        match body {
            NodeBody::Thinking(t) => assert_eq!(t, expected),
            other => panic!("expected Thinking, got {:?}", other),
        }
    }

    // Assert `body` is a ToolCall named `expected_name`; returns the raw
    // arguments string so callers can inspect it further.
    fn assert_tool_call<'a>(body: &'a NodeBody, expected_name: &str) -> &'a str {
        match body {
            NodeBody::ToolCall { name, arguments } => {
                assert_eq!(name, expected_name);
                arguments
            }
            other => panic!("expected ToolCall, got {:?}", other),
        }
    }
||
|
||
    // -- XML parsing tests ----------------------------------------------------

    #[test]
    fn test_tool_call_xml_parse_clean() {
        // Well-formed XML-style tool call parses to (name, JSON arguments).
        let body = "<function=bash>\n<parameter=command>poc-memory used core-personality</parameter>\n</function>";
        let (name, args) = parse_tool_call_body(body).unwrap();
        assert_eq!(name, "bash");
        let args: serde_json::Value = serde_json::from_str(&args).unwrap();
        assert_eq!(args["command"], "poc-memory used core-personality");
    }

    #[test]
    fn test_tool_call_xml_parse_streamed_whitespace() {
        // Streaming tokenization can insert whitespace in opening tags,
        // but close tags are always emitted verbatim.
        let body = "<\nfunction\n=\nbash\n>\n<\nparameter\n=\ncommand\n>pwd</parameter>\n</function>";
        let (name, args) = parse_tool_call_body(body).unwrap();
        assert_eq!(name, "bash");
        let args: serde_json::Value = serde_json::from_str(&args).unwrap();
        assert_eq!(args["command"], "pwd");
    }

    #[test]
    fn test_tool_call_json_parse() {
        // JSON-form tool calls are accepted alongside the XML form.
        let body = r#"{"name": "bash", "arguments": {"command": "ls"}}"#;
        let (name, args) = parse_tool_call_body(body).unwrap();
        assert_eq!(name, "bash");
        let args: serde_json::Value = serde_json::from_str(&args).unwrap();
        assert_eq!(args["command"], "ls");
    }

    #[test]
    fn test_tool_call_preserves_code_with_angle_brackets() {
        // Angle brackets inside parameter values must not confuse the parser.
        let body = "<function=edit>\n<parameter=code>if x < y {\n    std::mem::swap(&mut a, &mut b);\n}</parameter>\n</function>";
        let (name, args) = parse_tool_call_body(body).unwrap();
        assert_eq!(name, "edit");
        let args: serde_json::Value = serde_json::from_str(&args).unwrap();
        assert_eq!(args["code"], "if x < y {\n    std::mem::swap(&mut a, &mut b);\n}");
    }
|
||
|
||
    // -- ResponseParser tests -------------------------------------------------

    /// Set up a ContextState with an assistant branch, run the parser,
    /// return the children that were pushed into the branch.
    fn parse_into_ctx(chunks: &[&str]) -> (ContextState, Vec<PendingToolCall>) {
        let mut ctx = ContextState::new();
        ctx.push_no_log(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
        let mut p = ResponseParser::new(0);
        let mut calls = Vec::new();
        for chunk in chunks {
            // Feed each chunk as a single token (id=0 for tests)
            calls.extend(p.feed_token(chunk, &mut ctx));
        }
        p.finish(&mut ctx);
        (ctx, calls)
    }

    // Children of the first conversation node (the assistant branch pushed
    // by parse_into_ctx).
    fn assistant_children(ctx: &ContextState) -> &[AstNode] {
        ctx.conversation()[0].children()
    }
|
||
|
||
    #[test]
    fn test_parser_plain_text() {
        // No markup: a single Content leaf.
        let (ctx, _) = parse_into_ctx(&["hello world"]);
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 1);
        assert_content(b[0], "hello world");
    }

    #[test]
    fn test_parser_thinking_then_content() {
        // <think> section becomes a Thinking leaf; trailing text is Content.
        let (ctx, _) = parse_into_ctx(&["<think>reasoning</think>answer"]);
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 2);
        assert_thinking(b[0], "reasoning");
        assert_content(b[1], "answer");
    }

    #[test]
    fn test_parser_tool_call() {
        // A tool_call block yields both a PendingToolCall and a ToolCall leaf.
        let (ctx, calls) = parse_into_ctx(&[
            "<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>"
        ]);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "bash");
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 1);
        let args = assert_tool_call(b[0], "bash");
        let args: serde_json::Value = serde_json::from_str(args).unwrap();
        assert_eq!(args["command"], "ls");
    }

    #[test]
    fn test_parser_content_then_tool_call_then_content() {
        // Content on both sides of a tool call is preserved in order.
        let (ctx, _) = parse_into_ctx(&[
            "before",
            "<tool_call>\n<function=bash>\n<parameter=command>pwd</parameter>\n</function>\n</tool_call>",
            "after",
        ]);
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 3);
        assert_content(b[0], "before");
        assert_tool_call(b[1], "bash");
        assert_content(b[2], "after");
    }

    #[test]
    fn test_parser_incremental_feed() {
        // Feeding one character at a time must parse the same as one chunk.
        let text = "<think>thought</think>response";
        let mut ctx = ContextState::new();
        ctx.push_no_log(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
        let mut p = ResponseParser::new(0);
        for ch in text.chars() {
            p.feed_token(&ch.to_string(), &mut ctx);
        }
        p.finish(&mut ctx);
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 2);
        assert_thinking(b[0], "thought");
        assert_content(b[1], "response");
    }

    #[test]
    fn test_parser_incremental_tool_call() {
        // Character-at-a-time feeding still emits exactly one tool call.
        let text = "text<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>more";
        let mut ctx = ContextState::new();
        ctx.push_no_log(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
        let mut p = ResponseParser::new(0);
        let mut tool_calls = 0;
        for ch in text.chars() {
            tool_calls += p.feed_token(&ch.to_string(), &mut ctx).len();
        }
        p.finish(&mut ctx);
        assert_eq!(tool_calls, 1);
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 3);
        assert_content(b[0], "text");
        assert_tool_call(b[1], "bash");
        assert_content(b[2], "more");
    }

    #[test]
    fn test_parser_thinking_tool_call_content() {
        // All three element kinds in one response, in grammar order.
        let (ctx, _) = parse_into_ctx(&[
            "<think>let me think</think>",
            "<tool_call>\n<function=read>\n<parameter=path>/etc/hosts</parameter>\n</function>\n</tool_call>",
            "here's what I found",
        ]);
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 3);
        assert_thinking(b[0], "let me think");
        assert_tool_call(b[1], "read");
        assert_content(b[2], "here's what I found");
    }
|
||
|
||
    // -- Round-trip rendering tests -------------------------------------------
    // Each node kind must render to the exact ChatML text described in the
    // grammar at the top of this file.

    #[test]
    fn test_render_system_msg() {
        let node = AstNode::system_msg("you are helpful");
        assert_eq!(node.render(), "<|im_start|>system\nyou are helpful<|im_end|>\n");
    }

    #[test]
    fn test_render_user_msg() {
        let node = AstNode::user_msg("hello");
        assert_eq!(node.render(), "<|im_start|>user\nhello<|im_end|>\n");
    }

    #[test]
    fn test_render_assistant_with_thinking_and_content() {
        let node = AstNode::branch(Role::Assistant, vec![
            AstNode::thinking("hmm"),
            AstNode::content("answer"),
        ]);
        // Thinking renders as empty, content renders as-is
        assert_eq!(node.render(), "<|im_start|>assistant\nanswer<|im_end|>\n");
    }

    #[test]
    fn test_render_tool_result() {
        // Tool results self-wrap in a user message with <tool_response> tags.
        let node = AstNode::tool_result("output here");
        assert_eq!(node.render(), "<|im_start|>user\n<tool_response>\noutput here\n</tool_response><|im_end|>\n");
    }

    #[test]
    fn test_render_memory() {
        // The memory key is not part of the rendered text.
        let node = AstNode::memory("identity", "I am Proof of Concept");
        assert_eq!(node.render(), "<|im_start|>memory\nI am Proof of Concept<|im_end|>\n");
    }

    #[test]
    fn test_render_dmn() {
        let node = AstNode::dmn("subconscious prompt");
        assert_eq!(node.render(), "<|im_start|>dmn\nsubconscious prompt<|im_end|>\n");
    }

    #[test]
    fn test_render_tool_call() {
        // Rendered tool calls use the XML function/parameter form.
        let node = AstNode::tool_call("bash", r#"{"command":"ls"}"#);
        let rendered = node.render();
        assert!(rendered.contains("<tool_call>"));
        assert!(rendered.contains("<function=bash>"));
        assert!(rendered.contains("<parameter=command>"));
        assert!(rendered.contains("ls"));
        assert!(rendered.contains("</tool_call>"));
    }
|
||
|
||
    // -- Tokenizer round-trip tests -------------------------------------------
    // These require the tokenizer file; skipped if not present.

    // Initialize the global tokenizer from the well-known path; returns
    // false (caller should skip the test) when the file is absent.
    fn init_tokenizer() -> bool {
        let path = format!("{}/.consciousness/tokenizer-qwen35.json",
            std::env::var("HOME").unwrap_or_default());
        if std::path::Path::new(&path).exists() {
            tokenizer::init(&path);
            true
        } else {
            false
        }
    }

    // Core invariant: a node's cached token count equals its token ID list.
    fn assert_token_invariants(node: &AstNode) {
        assert_eq!(node.tokens(), node.token_ids().len(),
            "tokens() != token_ids().len()");
    }

    #[test]
    fn test_tokenize_roundtrip_leaf_types() {
        if !init_tokenizer() { return; }

        assert_token_invariants(&AstNode::system_msg("you are a helpful assistant"));
        assert_token_invariants(&AstNode::user_msg("what is 2+2?"));
        assert_token_invariants(&AstNode::tool_result("4"));
        assert_token_invariants(&AstNode::memory("identity", "I am Proof of Concept"));
        assert_token_invariants(&AstNode::dmn("check the memory store"));
        assert_token_invariants(&AstNode::tool_call("bash", r#"{"command":"ls -la"}"#));
    }

    #[test]
    fn test_tokenize_roundtrip_assistant_branch() {
        if !init_tokenizer() { return; }

        let node = AstNode::branch(Role::Assistant, vec![
            AstNode::content("here's what I found:\n"),
            AstNode::tool_call("bash", r#"{"command":"pwd"}"#),
            AstNode::content("\nthat's the current directory"),
        ]);
        assert_token_invariants(&node);
    }

    #[test]
    fn test_tokenize_invisible_nodes_are_zero() {
        if !init_tokenizer() { return; }

        // Thinking and Log nodes are not part of the prompt, so zero tokens.
        assert_eq!(AstNode::thinking("deep thoughts").tokens(), 0);
        assert_eq!(AstNode::log("debug info").tokens(), 0);
    }

    #[test]
    fn test_tokenize_decode_roundtrip() {
        if !init_tokenizer() { return; }

        // Content without special tokens round-trips through decode
        let text = "hello world, this is a test";
        let ids = tokenizer::encode(text);
        let decoded = tokenizer::decode(&ids);
        assert_eq!(decoded, text);
    }

    #[test]
    fn test_tokenize_context_state_matches_concatenation() {
        if !init_tokenizer() { return; }

        // Whole-context token count must equal the concatenated ID stream.
        let mut ctx = ContextState::new();
        ctx.push_no_log(Section::System, AstNode::system_msg("you are helpful"));
        ctx.push_no_log(Section::Identity, AstNode::memory("name", "Proof of Concept"));
        ctx.push_no_log(Section::Conversation, AstNode::user_msg("hi"));

        assert_eq!(ctx.tokens(), ctx.token_ids().len());
    }

    #[test]
    fn test_parser_roundtrip_through_tokenizer() {
        if !init_tokenizer() { return; }

        // Parsed streaming output must also satisfy the token invariant.
        let (ctx, _) = parse_into_ctx(&[
            "I'll check that for you",
            "<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>",
        ]);
        let node = &ctx.conversation()[0];
        assert_token_invariants(node);
        assert!(node.tokens() > 0);
    }
|
||
|
||
    // -- Timestamp deserialization tests ------------------------------------------

    #[test]
    fn test_timestamp_null_rejected() {
        // Missing/null timestamps used to be accepted via a lenient
        // deserialize fallback. Post-migration the schema is strict.
        let json = r#"{"Leaf":{"body":{"Content":"hello"},"timestamp":null}}"#;
        assert!(serde_json::from_str::<AstNode>(json).is_err());
    }

    #[test]
    fn test_timestamp_missing_rejected() {
        // Omitting the field entirely is rejected as well.
        let json = r#"{"Leaf":{"body":{"Content":"hello"}}}"#;
        assert!(serde_json::from_str::<AstNode>(json).is_err());
    }

    #[test]
    fn test_branch_timestamp_missing_rejected() {
        // Branch nodes carry timestamps too; same strictness applies.
        let json = r#"{"Branch":{"role":"User","children":[]}}"#;
        assert!(serde_json::from_str::<AstNode>(json).is_err());
    }
|
||
|
||
    // -- Image leaf tests ---------------------------------------------------------

    #[test]
    fn test_smart_resize_within_bounds() {
        // Typical case: 1024x768 → rounded to multiples of 32, under max.
        let (h, w) = smart_resize(768, 1024, 32, 65_536, 16_777_216);
        assert_eq!(h, 768);
        assert_eq!(w, 1024);
    }

    #[test]
    fn test_smart_resize_upscales_tiny() {
        // 32x32 = 1024 pixels, below min_pixels=65536. Should scale up.
        let (h, w) = smart_resize(32, 32, 32, 65_536, 16_777_216);
        assert!((h as u64) * (w as u64) >= 65_536,
            "resized {}x{} is under min_pixels", h, w);
        assert_eq!(h % 32, 0);
        assert_eq!(w % 32, 0);
    }

    #[test]
    fn test_smart_resize_downscales_huge() {
        // 8000x6000 = 48M pixels, above max_pixels=16M. Should scale down.
        let (h, w) = smart_resize(8000, 6000, 32, 65_536, 16_777_216);
        assert!((h as u64) * (w as u64) <= 16_777_216,
            "resized {}x{} exceeds max_pixels", h, w);
        assert_eq!(h % 32, 0);
        assert_eq!(w % 32, 0);
    }

    #[test]
    fn test_qwen3_token_count_matches_formula() {
        // 512x512 → resized to 512x512 (already multiple of 32, within bounds).
        // grid = 32x32, tokens = 32*32/4 = 256.
        assert_eq!(qwen3_image_token_count(512, 512), 256);
    }

    #[test]
    fn test_image_render_and_token_ids() {
        let node = AstNode::image(vec![0u8, 1, 2, 3], "image/png", 512, 512);
        let leaf = node.leaf().unwrap();
        // 3 tokens of bookend + 256 image_pad tokens
        assert_eq!(leaf.token_ids().len(), 258);
        assert_eq!(leaf.token_ids()[0], tokenizer::VISION_START);
        assert_eq!(leaf.token_ids()[257], tokenizer::VISION_END);
        for pad in &leaf.token_ids()[1..257] {
            assert_eq!(*pad, tokenizer::IMAGE_PAD);
        }
        // Rendered text has the expected bookends.
        let rendered = leaf.body().render();
        assert!(rendered.starts_with("<|vision_start|>"));
        assert!(rendered.ends_with("<|vision_end|>"));
    }

    #[test]
    fn test_image_serde_roundtrip() {
        let node = AstNode::image(vec![0xDE, 0xAD, 0xBE, 0xEF], "image/png", 64, 64);
        let json = serde_json::to_string(&node).unwrap();
        // bytes must be base64-encoded in the JSON form
        assert!(json.contains("3q2+7w=="));
        let back: AstNode = serde_json::from_str(&json).unwrap();
        let leaf = back.leaf().unwrap();
        match leaf.body() {
            NodeBody::Image { bytes, mime, orig_height, orig_width, token_count } => {
                assert_eq!(bytes, &[0xDE, 0xAD, 0xBE, 0xEF]);
                assert_eq!(mime, "image/png");
                assert_eq!(*orig_height, 64);
                assert_eq!(*orig_width, 64);
                assert_eq!(*token_count, qwen3_image_token_count(64, 64));
            }
            other => panic!("expected Image, got {:?}", other),
        }
        // token_ids are recomputed on deserialization
        assert_eq!(leaf.token_ids().len(), leaf.tokens());
    }

    #[test]
    fn test_timestamp_present_accepted() {
        // The happy path for the strict timestamp schema.
        let json = r#"{"Leaf":{"body":{"Content":"hi"},"timestamp":"2026-04-16T12:00:00Z"}}"#;
        let node: AstNode = serde_json::from_str(json).unwrap();
        let leaf = node.leaf().unwrap();
        assert_eq!(leaf.timestamp().to_rfc3339(),
            "2026-04-16T12:00:00+00:00");
    }
}
|