Wires the client side of the new salience protocol so inference
actually runs over gRPC instead of emitting the stubbed "not yet
wired" error. Each turn walks the AST as interleaved chunks, sends
only what's new to the server, and streams decode tokens back.
context.rs:
* `WireChunk` enum: `Tokens(Vec<u32>)` or `Image { bytes, mime,
known_expanded_len }`. Preserves the text/image/text ordering,
which the wire path can't flatten into a single token stream.
* `wire_chunks(range, skip)` walker, parallel to `wire_prompt` —
branches emit `<|im_start|>…<|im_end|>` tokens, image leaves
emit a single Image chunk (no inline vision tokens).
* `NodeLeaf::set_image_token_count(n)` + recompute of cached
`token_ids`; `ContextState::commit_image_token_counts(&[u32])`
fills in the first-N zero-count image leaves in wire order.
* `ResponseParser::run` handles the new
`StreamToken::ImageAppended` by committing the server's IMAGE_PAD
count N into the AST before the final Generate's Token events
stream in.
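Below is a minimal sketch of the ordering-preserving walk and the wire-order commit. It uses simplified stand-in types, all of which are assumptions rather than the real context.rs definitions:

- `Node` replaces `AstNode`.
- `IM_START`/`IM_END` are placeholder ids.
- `known_expanded_len` is taken to mean the image's IMAGE_PAD count; its exact semantics aren't spelled out in these notes.

The sketch also merges adjacent token runs, so each image splits the stream exactly once. That merging is a choice of the sketch.

```rust
// Hypothetical, simplified stand-ins: not the real `AstNode` / `WireChunk`.
#[derive(Debug, Clone, PartialEq)]
enum Node {
    Text(Vec<u32>),
    Image { bytes: Vec<u8>, token_count: u32 }, // 0 = not yet committed
    Branch(Vec<Node>),
}

#[derive(Debug, Clone, PartialEq)]
enum WireChunk {
    Tokens(Vec<u32>),
    // Assumed meaning: the IMAGE_PAD count, if the AST already knows it.
    Image { bytes: Vec<u8>, known_expanded_len: Option<u32> },
}

// Placeholder ids standing in for <|im_start|> / <|im_end|>.
const IM_START: u32 = 1;
const IM_END: u32 = 2;

/// Append token ids, merging into the previous chunk when it is also Tokens.
fn push_tokens(out: &mut Vec<WireChunk>, ids: &[u32]) {
    if let Some(WireChunk::Tokens(last)) = out.last_mut() {
        last.extend_from_slice(ids);
    } else {
        out.push(WireChunk::Tokens(ids.to_vec()));
    }
}

fn walk(node: &Node, out: &mut Vec<WireChunk>) {
    match node {
        Node::Text(ids) => push_tokens(out, ids),
        // An image leaf becomes one Image chunk: no inline vision tokens.
        Node::Image { bytes, token_count } => out.push(WireChunk::Image {
            bytes: bytes.clone(),
            known_expanded_len: (*token_count > 0).then_some(*token_count),
        }),
        Node::Branch(children) => {
            push_tokens(out, &[IM_START]);
            for child in children {
                walk(child, out);
            }
            push_tokens(out, &[IM_END]);
        }
    }
}

fn wire_chunks(nodes: &[Node]) -> Vec<WireChunk> {
    let mut out = Vec::new();
    for node in nodes {
        walk(node, &mut out);
    }
    out
}

/// Fill the first `counts.len()` zero-count images, in wire order.
fn commit_image_token_counts(nodes: &mut [Node], counts: &[u32]) {
    fn visit(node: &mut Node, counts: &[u32], idx: &mut usize) {
        match node {
            Node::Image { token_count, .. } if *token_count == 0 => {
                if let Some(&n) = counts.get(*idx) {
                    *token_count = n;
                    *idx += 1;
                }
            }
            Node::Branch(children) => {
                for child in children.iter_mut() {
                    visit(child, counts, idx);
                }
            }
            _ => {}
        }
    }
    let mut idx = 0;
    for node in nodes.iter_mut() {
        visit(node, counts, &mut idx);
    }
}
```

Zero-count images come out with `known_expanded_len: None`. This is the "unknown-length image" case that `run_session_generate` bails on when it appears in the committed prefix.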
salience.rs:
* `SessionHandle` tracks `committed_len`. `append_image` advances
it from the RPC response. New `generate(req)` opens the
server-streaming RPC.
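Below is a synchronous sketch of the `committed_len` bookkeeping. It is combined with the take-out/put-back handling that the api/mod.rs notes describe for `run_session_generate`. The following are all hypothetical:

- the names `prefill` and `with_session`;
- the assumption that the AppendImage response reports how many positions the image occupied.

The real code is async and talks gRPC.

```rust
use std::sync::Mutex;

/// Hypothetical stand-in for salience.rs's `SessionHandle`.
#[derive(Debug)]
struct SessionHandle {
    committed_len: usize,
}

impl SessionHandle {
    fn open() -> Self {
        SessionHandle { committed_len: 0 }
    }

    /// Stand-in for a max_tokens=0 Generate: prefilled tokens become committed.
    fn prefill(&mut self, tokens: &[u32]) {
        self.committed_len += tokens.len();
    }

    /// Assumed: the AppendImage response reports the positions the image took.
    fn append_image(&mut self, positions_from_response: usize) {
        self.committed_len += positions_from_response;
    }
}

/// Take the handle out of the slot (or open a fresh one), run `f` without
/// holding the lock, and put the handle back only if `f` succeeded. On error
/// the handle is dropped, so the next call reopens.
fn with_session<T, E>(
    slot: &Mutex<Option<SessionHandle>>,
    f: impl FnOnce(&mut SessionHandle) -> Result<T, E>,
) -> Result<T, E> {
    let mut handle = slot.lock().unwrap().take().unwrap_or_else(SessionHandle::open);
    let result = f(&mut handle);
    if result.is_ok() {
        *slot.lock().unwrap() = Some(handle);
    }
    result
}
```

Taking the handle out of the slot means the lock is not held across the (slow) RPCs. Dropping it on error discards a `committed_len` that may no longer match what the server holds.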
api/mod.rs:
* `stream_session_mm(session_lock, chunks, sampling, priority,
readout_shape)` replaces the stub. Spawns `run_session_generate`.
* `run_session_generate`: takes the session out of the Mutex (or
opens a fresh one) and skips chunks already covered by
`committed_len`, bailing on a chunk that straddles that boundary
or an unknown-length image in the committed prefix. It then walks
the delta: Tokens accumulate into `pending`; on an Image it flushes
pending via `flush_pending` (a max_tokens=0 Generate that only
prefills), then sends AppendImage and emits
StreamToken::ImageAppended. The final Generate carries any trailing
pending text as `append_tokens` plus the sampling params; its Token
events stream out as StreamToken::Token and Done as
StreamToken::Done. On success the handle, with updated
`committed_len`, returns to the Mutex; on error it is dropped and
the next call reopens.
* `StreamToken::ImageAppended { placeholder_count }` variant —
emitted in wire order before the final Generate's tokens.
* Prefix-cache cap for readout coverage: `readout_ranges` covers
`[prompt_len_after_append, u32::MAX)` when the caller provides
a readout_shape, so decode positions stream their readouts.
agent/mod.rs:
* `assemble_prompt` returns `Vec<WireChunk>` with the assistant
prologue merged into the trailing Tokens chunk. Caller in
`turn` passes chunks + readout_shape (pulled from
`agent.readout.lock().manifest`) to `stream_session_mm`.
* Dropped `assemble_prompt_tokens` — dead.
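Below is a sketch of the prologue merge in `assemble_prompt`. It uses a simplified `WireChunk` whose image payload is reduced to an id, and made-up prologue ids; both are assumptions.

```rust
/// Simplified stand-in for `WireChunk` (image payload reduced to an id).
#[derive(Debug, Clone, PartialEq)]
enum WireChunk {
    Tokens(Vec<u32>),
    Image(u32),
}

/// Append the assistant prologue, merging it into the trailing Tokens chunk
/// when there is one; otherwise (empty list or trailing Image) start a new one.
fn append_prologue(mut chunks: Vec<WireChunk>, prologue: &[u32]) -> Vec<WireChunk> {
    match chunks.last_mut() {
        Some(WireChunk::Tokens(tail)) => tail.extend_from_slice(prologue),
        _ => chunks.push(WireChunk::Tokens(prologue.to_vec())),
    }
    chunks
}
```

When the list ends in an Image or is empty, this sketch starts a new Tokens chunk. How the real `assemble_prompt` handles that case isn't stated in these notes.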
mind + unconscious:
* `Unconscious::new(client)` stores a shared `ApiClient`. Fixes
the repeated-manifest-fetch bug caused by each subagent's
`ApiClient::new` having its own OnceCell. The client's Arc-
wrapped manifest cache is now shared across every agent Mind
spawns.
* `prepare_spawn(name, auto, wake, base_client)` clones the base
client and overrides `.model` for the resolved backend instead
of constructing a new one. All three callers
(`toggle`/`trigger`/unconscious loop) pass `self.client.clone()`.
* `Mind::new` passes `agent.client.clone()` into
`Unconscious::new`.
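Below is a sketch of why sharing fixes the repeated fetch. The cache lives behind an `Arc`, so every clone made in `prepare_spawn` sees the same once-cell. The sketch makes three substitutions:

- `std::sync::OnceLock` stands in for the OnceCell the notes mention.
- A static counter replaces the real HTTP fetch.
- `ApiClient` is a pared-down, hypothetical version of the real client.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};

// Counts simulated manifest fetches so the sharing is observable.
static FETCHES: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical, pared-down client: the manifest cache is Arc-wrapped, so
/// clones share one cache instead of each owning a separate once-cell.
#[derive(Clone)]
struct ApiClient {
    model: String,
    manifest_cache: Arc<OnceLock<String>>,
}

impl ApiClient {
    fn new(model: &str) -> Self {
        ApiClient { model: model.to_string(), manifest_cache: Arc::new(OnceLock::new()) }
    }

    /// Fetch on first use only; later calls (from any clone) hit the cache.
    fn manifest(&self) -> &str {
        self.manifest_cache.get_or_init(|| {
            FETCHES.fetch_add(1, Ordering::SeqCst);
            "manifest".to_string() // stands in for the HTTP fetch
        })
    }
}

/// Mirrors the `prepare_spawn` change: clone the base client and override
/// `.model`, rather than calling `ApiClient::new` for each subagent.
fn prepare_spawn(base_client: &ApiClient, backend_model: &str) -> ApiClient {
    let mut client = base_client.clone();
    client.model = backend_model.to_string();
    client
}
```

Calling `ApiClient::new` per subagent instead would give each one its own empty cache, and the counter would reach one fetch per client.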
subconscious/generate.rs:
* gen_continuation switched to `wire_chunks` + the new
`stream_session_mm` signature. An ephemeral session opens on each
call and tears down at scope end. No readouts requested.
Not changed yet, noted for follow-up:
* Subconscious ablation scoring in learn.rs still talks to
`/v1/score` over HTTP. Will migrate once we have time to verify
the Generate+max_tokens=0+prompt_logprobs path end-to-end.
* compare.rs constructs its own ApiClient for the
`compare.test_backend` (which is intentionally a different
endpoint) — left alone.
* Readout manifest still fetched via HTTP at Agent::new.
Migration to GetReadoutManifest gRPC is a separate cleanup.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
1935 lines
71 KiB
Rust
// context.rs — Context window as an AST
//
// The context window is a tree of AstNodes. Each node is either a leaf
// (typed content with cached token IDs) or a branch (role + children).
// The full prompt is a depth-first traversal of the sections in ContextState.
// Streaming responses are parsed into new nodes by the ResponseParser.
//
// Grammar (EBNF):
//
//   context = section* ;
//   section = (message | leaf)* ;
//   message = IM_START role "\n" element* IM_END "\n" ;
//   role = "system" | "user" | "assistant" ;
//   element = thinking | tool_call | content ;
//   thinking = "<think>\n" TEXT "\n</think>\n" ;
//   tool_call = "<tool_call>\n" tool_xml "\n</tool_call>\n" ;
//   tool_xml = "<function=" NAME ">\n" param* "</function>" ;
//   param = "<parameter=" NAME ">\n" VALUE "\n</parameter>\n" ;
//   content = TEXT ;
//
// Self-wrapping leaves (not inside a message branch):
//   dmn = IM_START "dmn\n" TEXT IM_END "\n" ;
//   memory = IM_START "memory\n" TEXT IM_END "\n" ;
//   tool_result = IM_START "user\n<tool_response>\n" TEXT "\n</tool_response>" IM_END "\n" ;
//
// Non-visible leaves (not in prompt):
//   log = TEXT ;
//
// Role is only for branch (interior) nodes. Leaf type is determined by
// the NodeBody variant. Grammar constraints enforced by construction.

use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize};
use std::sync::OnceLock;
use super::tokenizer;

// Cached token lengths for role headers — computed once on first use.
// "system\n", "user\n", "assistant\n" and "\n" are fixed strings.
static ROLE_TOKENS: OnceLock<[usize; 3]> = OnceLock::new();
static NEWLINE_TOKENS: OnceLock<usize> = OnceLock::new();

fn role_header_tokens(role: Role) -> usize {
    let tokens = ROLE_TOKENS.get_or_init(|| [
        tokenizer::encode("system\n").len(),
        tokenizer::encode("user\n").len(),
        tokenizer::encode("assistant\n").len(),
    ]);
    match role {
        Role::System => tokens[0],
        Role::User => tokens[1],
        Role::Assistant => tokens[2],
    }
}

fn newline_tokens() -> usize {
    *NEWLINE_TOKENS.get_or_init(|| tokenizer::encode("\n").len())
}

// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------

/// Branch roles — maps directly to the grammar's message roles.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Role {
    System,
    User,
    Assistant,
}

/// Leaf content — each variant knows how to render itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NodeBody {
    // Children of message branches — rendered without im_start/im_end
    Content(String),
    Thinking(String),
    ToolCall { name: String, arguments: String },

    // Self-wrapping leaves — render their own im_start/im_end
    ToolResult(String),
    Memory { key: String, text: String, score: Option<f64> },
    Dmn(String),

    // Vision input — rendered as <|vision_start|> <|image_pad|>×N <|vision_end|>.
    // `token_count` is N, the count vLLM will compute for this image's grid;
    // 0 means not yet known (`ContextState::commit_image_token_counts`
    // fills it in from the server).
    Image {
        #[serde(with = "b64_bytes")]
        bytes: Vec<u8>,
        mime: String,
        orig_height: u32,
        orig_width: u32,
        token_count: u32,
    },

    // Non-visible (0 tokens in prompt)
    Log(String),
}

mod b64_bytes {
    use base64::{Engine, engine::general_purpose::STANDARD};
    use serde::{Serializer, Deserializer, Deserialize};
    pub fn serialize<S: Serializer>(bytes: &[u8], s: S) -> Result<S::Ok, S::Error> {
        s.serialize_str(&STANDARD.encode(bytes))
    }
    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
        let s = String::deserialize(d)?;
        STANDARD.decode(s).map_err(serde::de::Error::custom)
    }
}

/// A leaf node: typed content with cached token IDs.
/// Token IDs are not serialized — they're recomputed on deserialization.
#[derive(Debug, Clone, Serialize)]
pub struct NodeLeaf {
    body: NodeBody,
    #[serde(skip)]
    token_ids: Vec<u32>,
    timestamp: DateTime<Utc>,
}

impl<'de> Deserialize<'de> for NodeLeaf {
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        #[derive(Deserialize)]
        struct Raw {
            body: NodeBody,
            timestamp: DateTime<Utc>,
        }
        let raw = Raw::deserialize(deserializer)?;
        let token_ids = raw.body.compute_token_ids();
        Ok(NodeLeaf { body: raw.body, token_ids, timestamp: raw.timestamp })
    }
}

/// A node in the context AST.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstNode {
    Leaf(NodeLeaf),
    Branch {
        role: Role,
        children: Vec<AstNode>,
        timestamp: DateTime<Utc>,
        /// Per-response memory attribution from full scoring matrix.
        /// Maps memory key → divergence score for this response.
        #[serde(default, skip_serializing_if = "std::collections::BTreeMap::is_empty")]
        memory_scores: std::collections::BTreeMap<String, f64>,
    },
}

/// The context window: four sections as Vec<AstNode>.
/// All mutation goes through ContextState methods to maintain the invariant
/// that the token_ids on every leaf match its rendered text.
pub struct ContextState {
    system: Vec<AstNode>,
    identity: Vec<AstNode>,
    journal: Vec<AstNode>,
    conversation: Vec<AstNode>,
    pub conversation_log: Option<crate::mind::log::ConversationLog>,
}

impl Clone for ContextState {
    fn clone(&self) -> Self {
        Self {
            system: self.system.clone(),
            identity: self.identity.clone(),
            journal: self.journal.clone(),
            conversation: self.conversation.clone(),
            conversation_log: None, // forked contexts don't log
        }
    }
}

/// Identifies a section for mutation methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Section {
    System,
    Identity,
    Journal,
    Conversation,
}

/// Ephemeral handle for dispatching a tool call. Not persisted in the AST.
#[derive(Debug, Clone)]
pub struct PendingToolCall {
    pub name: String,
    pub arguments: String,
    pub id: String,
}

pub trait Ast {
    fn render(&self) -> String;
    fn token_ids(&self) -> Vec<u32>;
    fn tokens(&self) -> usize;
}

pub struct ResponseParser {
    branch_idx: usize,
    call_counter: u32,
    buf: String,
    content_parts: Vec<String>,
    in_think: bool,
    think_buf: String,
    in_tool_call: bool,
    tool_call_buf: String,
}

impl Role {
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::System => "system",
            Self::User => "user",
            Self::Assistant => "assistant",
        }
    }
}

impl NodeBody {
    /// Render this leaf body to text for the prompt.
    fn render_into(&self, out: &mut String) {
        match self {
            Self::Content(text) => out.push_str(text),
            Self::Thinking(text) => {
                out.push_str("<think>\n");
                out.push_str(text);
                out.push_str("\n</think>\n");
            }
            Self::Log(_) => {},
            Self::ToolCall { name, arguments } => {
                out.push_str("<tool_call>\n");
                out.push_str(&format_tool_call_xml(name, arguments));
                out.push_str("\n</tool_call>\n");
            }
            Self::ToolResult(text) => {
                out.push_str("<|im_start|>user\n<tool_response>\n");
                out.push_str(text);
                out.push_str("\n</tool_response><|im_end|>\n");
            }
            Self::Memory { text, .. } => {
                out.push_str("<|im_start|>memory\n");
                out.push_str(text);
                out.push_str("<|im_end|>\n");
            }
            Self::Dmn(text) => {
                out.push_str("<|im_start|>dmn\n");
                out.push_str(text);
                out.push_str("<|im_end|>\n");
            }
            Self::Image { token_count, .. } => {
                out.push_str("<|vision_start|>");
                for _ in 0..*token_count {
                    out.push_str("<|image_pad|>");
                }
                out.push_str("<|vision_end|>");
            }
        }
    }

    /// Render this leaf body to a new String.
    fn render(&self) -> String {
        let mut s = String::new();
        self.render_into(&mut s);
        s
    }

    /// Whether this leaf contributes tokens to the prompt.
    fn is_prompt_visible(&self) -> bool {
        !matches!(self, Self::Log(_))
    }

    /// Hand-assemble token IDs for body types where running the tokenizer
    /// on the rendered text would be needlessly expensive (Image). Falls
    /// back to encoding the rendered text for everything else.
    fn compute_token_ids(&self) -> Vec<u32> {
        if !self.is_prompt_visible() {
            return Vec::new();
        }
        match self {
            Self::Image { token_count, .. } => {
                let mut ids = Vec::with_capacity(*token_count as usize + 2);
                ids.push(tokenizer::VISION_START);
                ids.extend(std::iter::repeat(tokenizer::IMAGE_PAD)
                    .take(*token_count as usize));
                ids.push(tokenizer::VISION_END);
                ids
            }
            _ => tokenizer::encode(&self.render()),
        }
    }

    /// The text content of this leaf (for display, not rendering).
    pub fn text(&self) -> &str {
        match self {
            Self::Content(t) | Self::Thinking(t) | Self::Log(t)
            | Self::ToolResult(t) | Self::Dmn(t) => t,
            Self::ToolCall { name, .. } => name,
            Self::Memory { text, .. } => text,
            Self::Image { mime, .. } => mime,
        }
    }
}

impl NodeLeaf {
    fn new(body: NodeBody) -> Self {
        let token_ids = body.compute_token_ids();
        Self { body, token_ids, timestamp: Utc::now() }
    }

    pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
        self.timestamp = ts;
        self
    }

    pub fn body(&self) -> &NodeBody { &self.body }
    pub fn token_ids(&self) -> &[u32] { &self.token_ids }
    pub fn tokens(&self) -> usize { self.token_ids.len() }
    pub fn timestamp(&self) -> DateTime<Utc> { self.timestamp }

    /// If this is an Image leaf, update its IMAGE_PAD count to `n` and
    /// recompute cached `token_ids`. No-op on non-Image leaves —
    /// callers know the body shape via `body()`.
    pub fn set_image_token_count(&mut self, n: u32) {
        if let NodeBody::Image { token_count, .. } = &mut self.body {
            *token_count = n;
            self.token_ids = self.body.compute_token_ids();
        }
    }
}

impl AstNode {
    // -- Leaf constructors ----------------------------------------------------

    pub fn content(text: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::Content(text.into())))
    }

    pub fn thinking(text: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::Thinking(text.into())))
    }

    pub fn tool_call(name: impl Into<String>, arguments: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::ToolCall {
            name: name.into(),
            arguments: arguments.into(),
        }))
    }

    pub fn tool_result(text: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::ToolResult(text.into())))
    }

    pub fn memory(key: impl Into<String>, text: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::Memory {
            key: key.into(),
            text: text.into(),
            score: None,
        }))
    }

    pub fn dmn(text: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::Dmn(text.into())))
    }

    pub fn log(text: impl Into<String>) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::Log(text.into())))
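    }

    /// Build an Image leaf. `token_count` is the IMAGE_PAD count, supplied
    /// by the caller (it may be precomputed from the image dimensions using
    /// Qwen3-VL's resizing rules). Pass 0 when it isn't known yet;
    /// `ContextState::commit_image_token_counts` fills in the server's count.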
    pub fn image(
        bytes: Vec<u8>,
        mime: impl Into<String>,
        orig_height: u32,
        orig_width: u32,
        token_count: u32,
    ) -> Self {
        Self::Leaf(NodeLeaf::new(NodeBody::Image {
            bytes,
            mime: mime.into(),
            orig_height,
            orig_width,
            token_count,
        }))
    }

    // -- Branch constructors --------------------------------------------------

    pub fn branch(role: Role, children: Vec<AstNode>) -> Self {
        Self::Branch { role, children, timestamp: Utc::now(), memory_scores: Default::default() }
    }

    pub fn system_msg(text: impl Into<String>) -> Self {
        Self::Branch {
            role: Role::System,
            children: vec![Self::content(text)],
            timestamp: Utc::now(),
            memory_scores: Default::default(),
        }
    }

    pub fn user_msg(text: impl Into<String>) -> Self {
        Self::Branch {
            role: Role::User,
            children: vec![Self::content(text)],
            timestamp: Utc::now(),
            memory_scores: Default::default(),
        }
    }

    // -- Builder --------------------------------------------------------------

    pub fn retokenize(self) -> Self {
        match self {
            Self::Leaf(leaf) => {
                let token_ids = leaf.body.compute_token_ids();
                Self::Leaf(NodeLeaf { token_ids, ..leaf })
            }
            Self::Branch { role, children, timestamp, memory_scores } => Self::Branch {
                role,
                children: children.into_iter().map(|c| c.retokenize()).collect(),
                timestamp,
                memory_scores,
            },
        }
    }

    pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
        match &mut self {
            Self::Leaf(leaf) => leaf.timestamp = ts,
            Self::Branch { timestamp, .. } => *timestamp = ts,
        }
        self
    }

    pub fn children(&self) -> &[AstNode] {
        match self {
            Self::Branch { children, .. } => children,
            Self::Leaf(_) => &[],
        }
    }

    pub fn leaf(&self) -> Option<&NodeLeaf> {
        match self {
            Self::Leaf(l) => Some(l),
            _ => None,
        }
    }

    /// Short label for the UI.
    pub fn label(&self) -> String {
        let app = crate::config::app();
        match self {
            Self::Branch { role, children, .. } => {
                let preview = children.first()
                    .and_then(|c| c.leaf())
                    .map(|l| truncate_preview(l.body.text(), 60))
                    .unwrap_or_default();
                match role {
                    Role::System => "system".into(),
                    Role::User => format!("{}: {}", app.user_name, preview),
                    Role::Assistant => format!("{}: {}", app.assistant_name, preview),
                }
            }
            Self::Leaf(leaf) => match &leaf.body {
                NodeBody::Content(t) => truncate_preview(t, 60),
                NodeBody::Thinking(t) => format!("thinking: {}", truncate_preview(t, 60)),
                NodeBody::ToolCall { name, arguments } => format!("tool: {}({})", name, truncate_preview(arguments, 80)),
                NodeBody::ToolResult(_) => "tool_result".into(),
                NodeBody::Memory { key, score, .. } => match score {
                    Some(s) => format!("mem: {} score:{:.1}", key, s),
                    None => format!("mem: {}", key),
                },
                NodeBody::Dmn(_) => "dmn".into(),
                NodeBody::Image { orig_height, orig_width, token_count, .. } =>
                    format!("image: {}x{} ({} tokens)", orig_width, orig_height, token_count),
                NodeBody::Log(t) => format!("log: {}", truncate_preview(t, 60)),
            },
        }
    }
}

impl AstNode {
    fn render_into(&self, out: &mut String) {
        match self {
            Self::Leaf(leaf) => leaf.body.render_into(out),
            Self::Branch { role, children, .. } => {
                out.push_str(&format!("<|im_start|>{}\n", role.as_str()));
                for child in children {
                    child.render_into(out);
                }
                out.push_str("<|im_end|>\n");
            }
        }
    }

    fn token_ids_into(&self, out: &mut Vec<u32>) {
        match self {
            Self::Leaf(leaf) => out.extend_from_slice(&leaf.token_ids),
            Self::Branch { role, children, .. } => {
                out.push(tokenizer::IM_START);
                out.extend(tokenizer::encode(&format!("{}\n", role.as_str())));
                for child in children {
                    child.token_ids_into(out);
                }
                out.push(tokenizer::IM_END);
                out.extend(tokenizer::encode("\n"));
            }
        }
    }
}

impl Ast for AstNode {
    fn render(&self) -> String {
        let mut s = String::new();
        self.render_into(&mut s);
        s
    }

    fn token_ids(&self) -> Vec<u32> {
        let mut ids = Vec::new();
        self.token_ids_into(&mut ids);
        ids
    }

    fn tokens(&self) -> usize {
        match self {
            Self::Leaf(leaf) => leaf.tokens(),
            Self::Branch { role, children, .. } => {
                1 + role_header_tokens(*role)
                    + children.iter().map(|c| c.tokens()).sum::<usize>()
                    + 1 + newline_tokens()
            }
        }
    }
}

fn truncate_preview(s: &str, max: usize) -> String {
    let preview: String = s.chars().take(max).collect();
    let preview = preview.replace('\n', " ");
    // Compare char counts, not byte length: `s.len()` would mark multi-byte
    // text as truncated even when it fits within `max` chars.
    if s.chars().count() > max { format!("{}...", preview) } else { preview }
}

fn format_tool_call_xml(name: &str, args_json: &str) -> String {
    let args: serde_json::Value = serde_json::from_str(args_json)
        .unwrap_or(serde_json::Value::Object(Default::default()));
    let mut xml = format!("<function={}>\n", name);
    if let Some(obj) = args.as_object() {
        for (key, value) in obj {
            let val_str = match value {
                serde_json::Value::String(s) => s.clone(),
                other => other.to_string(),
            };
            xml.push_str(&format!("<parameter={}>\n{}\n</parameter>\n", key, val_str));
        }
    }
    xml.push_str("</function>");
    xml
}

/// Search for a sequence of literal parts separated by optional ASCII whitespace.
/// Returns (start, end) byte positions of the overall match.
///
/// Handles the case where streaming tokenization inserts whitespace inside
/// XML tag structure, e.g. `< function = bash >` instead of `<function=bash>`.
fn find_ws_seq(s: &str, parts: &[&str]) -> Option<(usize, usize)> {
    let bytes = s.as_bytes();
    let mut search_from = 0;
    'outer: loop {
        let start = s[search_from..].find(parts[0])? + search_from;
        let mut pos = start + parts[0].len();
        for &part in &parts[1..] {
            while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
                pos += 1;
            }
            if !s[pos..].starts_with(part) {
                search_from = start + 1;
                continue 'outer;
            }
            pos += part.len();
        }
        return Some((start, pos));
    }
}

/// Parse a Qwen-style XML tag: `<tag=name>body</tag>`.
/// Tolerates whitespace inside tag delimiters (streaming artifact).
/// Body content is returned verbatim except for a single leading/trailing
/// newline (XML formatting convention).
fn parse_qwen_tag<'a>(s: &'a str, tag: &str) -> Option<(&'a str, &'a str, &'a str)> {
    // Open tag: tolerate whitespace from streaming tokenization
    let (_, after_eq) = find_ws_seq(s, &["<", tag, "="])?;
    let gt_offset = s[after_eq..].find('>')?;
    let name = s[after_eq..after_eq + gt_offset].trim();
    let body_start = after_eq + gt_offset + 1;

    // Close tag: exact match — model doesn't insert whitespace in close tags
    let close = format!("</{}>", tag);
    let close_offset = s[body_start..].find(&close)?;
    let body = &s[body_start..body_start + close_offset];
    // Strip the single leading/trailing newline from XML formatting,
    // but preserve all other whitespace (indentation matters for code).
    let body = body.strip_prefix('\n').unwrap_or(body);
    let body = body.strip_suffix('\n').unwrap_or(body);
    let rest = &s[body_start + close_offset + close.len()..];

    Some((name, body, rest))
}

fn parse_tool_call_body(body: &str) -> Option<(String, String)> {
    let body = body.trim();
    parse_xml_tool_call(body)
        .or_else(|| parse_json_tool_call(body))
}

fn parse_xml_tool_call(body: &str) -> Option<(String, String)> {
    let (func_name, func_body, _) = parse_qwen_tag(body, "function")?;
    let mut args = serde_json::Map::new();
    let mut rest = func_body;
    while let Some((key, val, remainder)) = parse_qwen_tag(rest, "parameter") {
        let value = serde_json::from_str(val)
            .unwrap_or(serde_json::Value::String(val.to_string()));
        args.insert(key.to_string(), value);
        rest = remainder;
    }
    Some((func_name.to_string(), serde_json::to_string(&args).unwrap_or_default()))
}

fn parse_json_tool_call(body: &str) -> Option<(String, String)> {
    let v: serde_json::Value = serde_json::from_str(body).ok()?;
    let name = v["name"].as_str()?;
    let arguments = &v["arguments"];
    Some((name.to_string(), serde_json::to_string(arguments).unwrap_or_default()))
}

/// Search `buf` for `close_tag`. If found, append everything before it to
/// `accum`, advance `buf` past the tag, and return the accumulated content.
/// If not found, drain the safe prefix (preserving any partial tag match at
/// the end of buf) into `accum`.
fn scan_close_tag(buf: &mut String, close_tag: &str, accum: &mut String) -> Option<String> {
    if let Some(pos) = buf.find(close_tag) {
        accum.push_str(&buf[..pos]);
        *buf = buf[pos + close_tag.len()..].to_string();
        Some(std::mem::take(accum))
    } else {
        let drained = drain_safe(buf, close_tag.len());
        if !drained.is_empty() {
            accum.push_str(&drained);
        }
        None
    }
}

/// Remove everything from `buf` except the last `tag_len` bytes, which might
/// be a partial tag. Returns the removed prefix.
fn drain_safe(buf: &mut String, tag_len: usize) -> String {
    let safe = buf.len().saturating_sub(tag_len);
    if safe > 0 {
        let safe = buf.floor_char_boundary(safe);
        let drained = buf[..safe].to_string();
        *buf = buf[safe..].to_string();
        drained
    } else {
        String::new()
    }
}

impl ResponseParser {
    /// @in_think: whether the model's output begins inside a <think> block.
    /// Set when the prompt was prefilled with "<think>\n" (native thinking
    /// mode) so the parser captures reasoning tokens as Thinking until the
    /// model emits </think>.
    pub fn new(branch_idx: usize, in_think: bool) -> Self {
        Self {
            branch_idx,
            call_counter: 0,
            buf: String::new(),
            content_parts: Vec::new(),
            in_think,
            think_buf: String::new(),
            in_tool_call: false,
            tool_call_buf: String::new(),
        }
    }

    /// Consume a token stream, parse into the AST, yield tool calls.
    /// Spawns a background task. Returns a tool call receiver and a
    /// join handle that resolves to Ok(()) or the stream error.
    pub fn run(
        self,
        mut stream: tokio::sync::mpsc::UnboundedReceiver<super::api::StreamToken>,
        agent: std::sync::Arc<super::Agent>,
    ) -> (
        tokio::sync::mpsc::UnboundedReceiver<PendingToolCall>,
        tokio::task::JoinHandle<anyhow::Result<()>>,
    ) {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let handle = tokio::spawn(async move {
|
||
let mut parser = self;
|
||
let agent_name = agent.state.lock().await.provenance.clone();
|
||
let log_path = format!("/tmp/poc-{}.log", agent_name);
|
||
let mut log_file = std::fs::OpenOptions::new()
|
||
.create(true).append(true).open(&log_path).ok();
|
||
let mut full_text = String::new();
|
||
while let Some(event) = stream.recv().await {
|
||
match event {
|
||
super::api::StreamToken::Token { id, readout } => {
|
||
if let Some(r) = readout {
|
||
if let Ok(mut buf) = agent.readout.lock() {
|
||
buf.push(id, r);
|
||
}
|
||
}
|
||
let text = super::tokenizer::decode(&[id]);
|
||
full_text.push_str(&text);
|
||
let mut ctx = agent.context.lock().await;
|
||
let calls = parser.feed_token(&text, &mut ctx);
|
||
if !calls.is_empty() {
|
||
if let Some(ref mut f) = log_file {
|
||
use std::io::Write;
|
||
for c in &calls {
|
||
let end = c.arguments.floor_char_boundary(c.arguments.len().min(200));
|
||
let _ = writeln!(f, "tool_call: {} args={}", c.name, &c.arguments[..end]);
|
||
}
|
||
}
|
||
}
|
||
for call in calls {
|
||
let _ = tx.send(call);
|
||
}
|
||
}
|
||
super::api::StreamToken::Done { usage } => {
|
||
if let Some(ref mut f) = log_file {
|
||
use std::io::Write;
|
||
let ctx = agent.context.lock().await;
|
||
let children = ctx.conversation().get(parser.branch_idx)
|
||
.map(|n| n.children()).unwrap_or(&[]);
|
||
let n_think = children.iter().filter(|c| matches!(c.leaf().map(|l| l.body()), Some(NodeBody::Thinking(_)))).count();
|
||
let n_content = children.iter().filter(|c| matches!(c.leaf().map(|l| l.body()), Some(NodeBody::Content(_)))).count();
|
||
let n_tool = children.iter().filter(|c| matches!(c.leaf().map(|l| l.body()), Some(NodeBody::ToolCall { .. }))).count();
|
||
let _ = writeln!(f, "done: {} chars, {} content + {} think + {} tool_call, ctx: {} tokens",
|
||
full_text.len(), n_content, n_think, n_tool, ctx.tokens());
|
||
drop(ctx);
|
||
if full_text.len() > 0 && n_content == 0 && n_tool == 0 {
|
||
let end = full_text.floor_char_boundary(full_text.len().min(2000));
|
||
let _ = writeln!(f, " unparsed text: {}", &full_text[..end]);
|
||
}
|
||
}
|
||
if let Some(u) = usage {
|
||
agent.state.lock().await.last_prompt_tokens = u.prompt_tokens;
|
||
}
|
||
let mut ctx = agent.context.lock().await;
|
||
parser.finish(&mut ctx);
|
||
return Ok(());
|
||
}
|
||
super::api::StreamToken::ImageAppended { placeholder_count } => {
|
||
// Commit the server-authoritative IMAGE_PAD
|
||
// count into the first zero-count image leaf
|
||
// in wire order. AppendImage always runs
|
||
// before the final Generate, so this fires
|
||
// before any Token events for this stream.
|
||
let mut ctx = agent.context.lock().await;
|
||
ctx.commit_image_token_counts(&[placeholder_count]);
|
||
}
|
||
super::api::StreamToken::Error(e) => {
|
||
return Err(anyhow::anyhow!("{}", e));
|
||
}
|
||
}
|
||
}
|
||
Ok(())
|
||
});
|
||
(rx, handle)
|
||
}
|
||
|
||
pub fn feed_token(&mut self, text: &str, ctx: &mut ContextState) -> Vec<PendingToolCall> {
|
||
const THINK_OPEN: &str = "<think>";
|
||
        const THINK_CLOSE: &str = "</think>";
        const TOOL_CALL_OPEN: &str = "<tool_call>";
        const TOOL_CALL_CLOSE: &str = "</tool_call>";
        const OPEN_TAGS: &[&str] = &[THINK_OPEN, TOOL_CALL_OPEN];

        let mut pending = Vec::new();
        self.buf.push_str(text);

        loop {
            if self.in_think {
                if let Some(content) = scan_close_tag(&mut self.buf, THINK_CLOSE, &mut self.think_buf) {
                    self.in_think = false;
                    let text = content.trim().to_string();
                    if !text.is_empty() {
                        self.push_child(ctx, AstNode::thinking(text));
                    }
                    continue;
                }
                break;
            }

            if self.in_tool_call {
                if let Some(content) = scan_close_tag(&mut self.buf, TOOL_CALL_CLOSE, &mut self.tool_call_buf) {
                    self.in_tool_call = false;
                    if let Some((name, args)) = parse_tool_call_body(&content) {
                        self.flush_content(ctx);
                        self.push_child(ctx, AstNode::tool_call(&name, &args));
                        self.call_counter += 1;
                        pending.push(PendingToolCall {
                            name,
                            arguments: args,
                            id: format!("call_{}", self.call_counter),
                        });
                    }
                    continue;
                }
                break;
            }

            // Not inside a tag: find the earliest opening tag.
            let next = OPEN_TAGS.iter()
                .filter_map(|tag| self.buf.find(tag).map(|pos| (pos, *tag)))
                .min_by_key(|(pos, _)| *pos);

            match next {
                Some((pos, tag)) => {
                    if pos > 0 {
                        self.content_parts.push(self.buf[..pos].to_string());
                    }
                    self.buf = self.buf[pos + tag.len()..].to_string();
                    self.flush_content(ctx);
                    match tag {
                        THINK_OPEN => self.in_think = true,
                        TOOL_CALL_OPEN => self.in_tool_call = true,
                        _ => unreachable!(),
                    }
                    continue;
                }
                None => {
                    // Keep a tail that might be a partial opening tag.
                    let max_tag = OPEN_TAGS.iter().map(|t| t.len()).max().unwrap();
                    let drained = drain_safe(&mut self.buf, max_tag);
                    if !drained.is_empty() {
                        self.content_parts.push(drained);
                    }
                    break;
                }
            }
        }

        pending
    }

    fn push_child(&self, ctx: &mut ContextState, child: AstNode) {
        ctx.push_child(Section::Conversation, self.branch_idx, child);
    }

    fn flush_content(&mut self, ctx: &mut ContextState) {
        if !self.content_parts.is_empty() {
            let text: String = self.content_parts.drain(..).collect();
            let text = text.trim().to_string();
            if !text.is_empty() {
                self.push_child(ctx, AstNode::content(text));
            }
        }
    }

    pub fn finish(mut self, ctx: &mut ContextState) {
        if !self.buf.is_empty() {
            self.content_parts.push(std::mem::take(&mut self.buf));
        }
        self.flush_content(ctx);
    }
}

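// Sketch (hypothetical, not called by the parser above): why `feed_token`
// holds back a tail. An opening tag can arrive split across decode tokens,
// e.g. "<thi" then "nk>", so text can only be released once it can no longer
// be the start of a tag. This standalone version uses a tighter rule than
// `drain_safe(buf, max_tag)`: it keeps only the longest suffix of the buffer
// that is a proper prefix of the tag being looked for. `TagSplitter` and
// `Segment` are illustrative names; tool calls and trimming are omitted.
#[allow(dead_code)]
mod sketch_tag_split {
    const OPEN: &str = "<think>";
    const CLOSE: &str = "</think>";

    #[derive(Debug, PartialEq)]
    pub enum Segment {
        Text(String),
        Think(String),
    }

    #[derive(Default)]
    pub struct TagSplitter {
        buf: String,
        in_think: bool,
        out: Vec<Segment>,
    }

    /// Length of the longest suffix of `s` that is a proper prefix of `tag`.
    fn partial_tag_len(s: &str, tag: &str) -> usize {
        (1..tag.len()).rev().find(|&n| s.ends_with(&tag[..n])).unwrap_or(0)
    }

    impl TagSplitter {
        pub fn feed(&mut self, chunk: &str) {
            self.buf.push_str(chunk);
            loop {
                let tag = if self.in_think { CLOSE } else { OPEN };
                if let Some(pos) = self.buf.find(tag) {
                    // Text before the tag belongs to the current mode.
                    let before = self.buf[..pos].to_string();
                    self.emit(before);
                    self.buf.replace_range(..pos + tag.len(), "");
                    self.in_think = !self.in_think;
                    continue;
                }
                // No complete tag: release all but a possible partial tag.
                let cut = self.buf.len() - partial_tag_len(&self.buf, tag);
                let ready: String = self.buf.drain(..cut).collect();
                self.emit(ready);
                return;
            }
        }

        pub fn finish(mut self) -> Vec<Segment> {
            let rest = std::mem::take(&mut self.buf);
            self.emit(rest);
            self.out
        }

        /// Append text in the current mode, merging with the previous
        /// segment when it is of the same kind.
        fn emit(&mut self, text: String) {
            if text.is_empty() {
                return;
            }
            let in_think = self.in_think;
            if let Some(last) = self.out.last_mut() {
                match (last, in_think) {
                    (Segment::Think(s), true) | (Segment::Text(s), false) => {
                        s.push_str(&text);
                        return;
                    }
                    _ => {}
                }
            }
            self.out.push(if in_think { Segment::Think(text) } else { Segment::Text(text) });
        }
    }
}
// Feeding "hi <thi", "nk>plan</th", "ink> done" yields Text("hi "),
// Think("plan"), Text(" done"): the split tags are still recognized.
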
impl ContextState {
    pub fn new() -> Self {
        Self {
            system: Vec::new(),
            identity: Vec::new(),
            journal: Vec::new(),
            conversation: Vec::new(),
            conversation_log: None,
        }
    }

    // -- Read access ----------------------------------------------------------

    pub fn system(&self) -> &[AstNode] { &self.system }
    pub fn identity(&self) -> &[AstNode] { &self.identity }
    pub fn journal(&self) -> &[AstNode] { &self.journal }
    pub fn conversation(&self) -> &[AstNode] { &self.conversation }
    pub fn conversation_mut(&mut self) -> &mut Vec<AstNode> { &mut self.conversation }

    pub fn sections(&self) -> [&Vec<AstNode>; 4] {
        [&self.system, &self.identity, &self.journal, &self.conversation]
    }

    /// Walk image leaves across all sections in wire order and fill in
    /// the first N leaves that have `token_count == 0` with successive
    /// values from `counts`. Used after a gRPC session's stream of
    /// AppendImage responses to commit the server's IMAGE_PAD counts
    /// back into the AST so the next wire walk doesn't see zero-count
    /// images in the already-committed prefix.
    pub fn commit_image_token_counts(&mut self, counts: &[u32]) {
        fn visit(node: &mut AstNode, counts: &[u32], idx: &mut usize) {
            if *idx >= counts.len() { return; }
            match node {
                AstNode::Leaf(leaf) => {
                    if let NodeBody::Image { token_count, .. } = leaf.body() {
                        if *token_count == 0 {
                            leaf.set_image_token_count(counts[*idx]);
                            *idx += 1;
                        }
                    }
                }
                AstNode::Branch { children, .. } => {
                    for c in children { visit(c, counts, idx); }
                }
            }
        }
        let mut idx = 0usize;
        for node in &mut self.system { visit(node, counts, &mut idx); }
        for node in &mut self.identity { visit(node, counts, &mut idx); }
        for node in &mut self.journal { visit(node, counts, &mut idx); }
        for node in &mut self.conversation { visit(node, counts, &mut idx); }
    }
}

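// Sketch (hypothetical toy types, not `AstNode`): the first-N, depth-first
// fill that `commit_image_token_counts` performs. Images that already carry
// a nonzero count are skipped, so each batch of server counts lands only on
// images that were still unknown, in wire order. Unlike the method above,
// the sketch returns how many counts were consumed, which a caller could
// compare with the number of ImageAppended events it received.
#[allow(dead_code)]
mod sketch_commit_counts {
    #[derive(Debug, PartialEq)]
    pub enum Node {
        Text,
        Image(u32), // IMAGE_PAD count; 0 = not yet known
        Branch(Vec<Node>),
    }

    /// Fill zero-count images depth-first; returns the number of counts used.
    pub fn commit(nodes: &mut [Node], counts: &[u32]) -> usize {
        fn visit(node: &mut Node, counts: &[u32], idx: &mut usize) {
            if *idx >= counts.len() {
                return;
            }
            match node {
                Node::Image(n) if *n == 0 => {
                    *n = counts[*idx];
                    *idx += 1;
                }
                Node::Branch(children) => {
                    for c in children.iter_mut() {
                        visit(c, counts, idx);
                    }
                }
                _ => {}
            }
        }
        let mut idx = 0;
        for n in nodes.iter_mut() {
            visit(n, counts, &mut idx);
        }
        idx
    }
}
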
impl Ast for ContextState {
    fn render(&self) -> String {
        let mut s = String::new();
        for section in self.sections() {
            for node in section {
                s.push_str(&node.render());
            }
        }
        s
    }

    fn token_ids(&self) -> Vec<u32> {
        let mut ids = Vec::new();
        for section in self.sections() {
            for node in section {
                ids.extend(node.token_ids());
            }
        }
        ids
    }

    fn tokens(&self) -> usize {
        self.sections().iter()
            .flat_map(|s| s.iter())
            .map(|n| n.tokens())
            .sum()
    }
}

/// An image collected from the AST for a request body. The AST stores
/// the pre-expanded token form (`<|vision_start|> + <|image_pad|>×N +
/// <|vision_end|>`), and the wire form mirrors that exactly so the
/// server's `session.tokens` length matches what vLLM's engine will
/// process. The authoritative N is obtained from the server via the
/// CountImageTokens RPC before the Image leaf is constructed.
#[derive(Clone)]
pub struct WireImage {
    pub bytes: Vec<u8>,
    pub mime: String,
}

/// One piece of the wire stream for the gRPC session path. Runs of
/// text/tool/thinking tokens are batched into `Tokens`; each Image
/// leaf becomes its own `Image` chunk because the server writes the
/// full vision block on AppendImage, so the client never sends vision
/// tokens inline. Order matches the AST's depth-first wire order.
#[derive(Clone)]
pub enum WireChunk {
    Tokens(Vec<u32>),
    Image {
        bytes: Vec<u8>,
        mime: String,
        /// Client's current best guess at how many tokens the server
        /// will expand this image to, including bookends. `0` means
        /// the count is unknown (view_image just loaded the image and
        /// AppendImage hasn't run yet). Callers use this only for offset
        /// bookkeeping: it is this chunk's contribution to the
        /// server-visible length when the chunk was already appended
        /// on a prior turn.
        known_expanded_len: u32,
    },
}

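// Sketch of the offset bookkeeping that `known_expanded_len` supports,
// restating the skip rule of the session path (`run_session_generate`) as a
// hypothetical standalone function, not that code. The server has already
// consumed `committed_len` tokens. Whole chunks inside that prefix are
// skipped; the walk bails if the boundary falls mid-chunk, or if an image
// inside the prefix still has an unknown (0) length. `Chunk` and
// `first_uncommitted` are illustrative names.
#[allow(dead_code)]
mod sketch_committed_prefix {
    pub enum Chunk {
        Tokens(Vec<u32>),
        /// Expanded length including bookends; 0 = not yet known.
        Image { known_expanded_len: u32 },
    }

    /// Index of the first chunk not yet on the server, or a reason why
    /// the committed prefix can't be matched against `chunks`.
    pub fn first_uncommitted(chunks: &[Chunk], committed_len: usize) -> Result<usize, String> {
        let mut offset = 0usize;
        for (i, chunk) in chunks.iter().enumerate() {
            if offset == committed_len {
                return Ok(i);
            }
            let len = match chunk {
                Chunk::Tokens(t) => t.len(),
                Chunk::Image { known_expanded_len: 0 } => {
                    return Err(format!("image chunk {} in committed prefix has unknown length", i));
                }
                Chunk::Image { known_expanded_len } => *known_expanded_len as usize,
            };
            if offset + len > committed_len {
                return Err(format!("committed_len {} falls inside chunk {}", committed_len, i));
            }
            offset += len;
        }
        if offset == committed_len {
            Ok(chunks.len())
        } else {
            Err(format!("committed_len {} exceeds wire length {}", committed_len, offset))
        }
    }
}
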
fn wire_into(node: &AstNode, tokens: &mut Vec<u32>, images: &mut Vec<WireImage>) {
    match node {
        AstNode::Leaf(leaf) => match leaf.body() {
            NodeBody::Image { bytes, mime, .. } => {
                // Send the pre-expanded token form (includes N
                // <|image_pad|> tokens); engine's multi_modal
                // pipeline pairs them with the binary data below.
                tokens.extend_from_slice(leaf.token_ids());
                images.push(WireImage {
                    bytes: bytes.clone(),
                    mime: mime.clone(),
                });
            }
            _ => tokens.extend_from_slice(leaf.token_ids()),
        },
        AstNode::Branch { role, children, .. } => {
            tokens.push(tokenizer::IM_START);
            tokens.extend(tokenizer::encode(&format!("{}\n", role.as_str())));
            for c in children {
                wire_into(c, tokens, images);
            }
            tokens.push(tokenizer::IM_END);
            tokens.extend(tokenizer::encode("\n"));
        }
    }
}

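// Sketch (hypothetical, not part of the wire path): `wire_into` sends each
// image's pre-expanded tokens inline and its bytes in a separate list, so the
// two must stay in step, with one `vision_start, pad*, vision_end` block per
// `WireImage`. A consistency check could recover the pad run of every block.
// The special-token IDs are parameters because their real values come from
// the tokenizer; `pad_runs` is an illustrative name.
#[allow(dead_code)]
mod sketch_vision_blocks {
    /// Pad-run length of each `start pad* end` block, in order. Returns
    /// `None` if a block is unterminated or contains other tokens, or if a
    /// pad/end token appears outside a block.
    pub fn pad_runs(tokens: &[u32], start: u32, pad: u32, end: u32) -> Option<Vec<usize>> {
        let mut runs = Vec::new();
        // Some(n): inside a block with n pads seen so far.
        let mut open: Option<usize> = None;
        for &t in tokens {
            match open {
                None if t == start => open = Some(0),
                None if t == pad || t == end => return None,
                None => {}
                Some(n) if t == pad => open = Some(n + 1),
                Some(n) if t == end => {
                    runs.push(n);
                    open = None;
                }
                Some(_) => return None,
            }
        }
        if open.is_some() { None } else { Some(runs) }
    }
}
// A caller could then check `runs.len() == images.len()` before sending.
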
pub fn memory_key(node: &AstNode) -> Option<&str> {
    match node {
        AstNode::Leaf(leaf) => match leaf.body() {
            NodeBody::Memory { key, .. } => Some(key),
            _ => None,
        },
        _ => None,
    }
}

pub fn is_memory_node(node: &AstNode) -> bool {
    matches!(node, AstNode::Leaf(leaf) if matches!(leaf.body(), NodeBody::Memory { .. }))
}

pub fn is_assistant(node: &AstNode) -> bool {
    matches!(node, AstNode::Branch { role: Role::Assistant, .. })
}

/// Concatenate the text of a Branch's Leaf children: what the model
/// actually produced on that turn (Content + Thinking + ToolCall name).
pub fn render_branch_text(children: &[AstNode]) -> String {
    children.iter()
        .filter_map(|c| match c {
            AstNode::Leaf(leaf) => Some(leaf.body().text().to_string()),
            _ => None,
        })
        .collect::<Vec<_>>()
        .join("")
}

/// Render the last `max_msgs` user/assistant branches before `idx` as a
/// review-friendly string with `[user]` / `[assistant]` markers.
pub fn render_prior_context(entries: &[AstNode], idx: usize, max_msgs: usize) -> String {
    let mut picked: Vec<&AstNode> = Vec::with_capacity(max_msgs);
    for i in (0..idx).rev() {
        if picked.len() >= max_msgs { break; }
        if let AstNode::Branch { role, .. } = &entries[i] {
            if matches!(role, Role::User | Role::Assistant) {
                picked.push(&entries[i]);
            }
        }
    }
    picked.reverse();

    let mut out = String::new();
    for node in picked {
        if let AstNode::Branch { role, children, .. } = node {
            let marker = match role {
                Role::User => "[user]",
                Role::Assistant => "[assistant]",
                _ => continue,
            };
            out.push_str(marker);
            out.push('\n');
            out.push_str(render_branch_text(children).trim());
            out.push_str("\n\n");
        }
    }
    out.trim_end().to_string()
}

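// Sketch (hypothetical `Msg`/`Role` stand-ins for `AstNode` branches): the
// selection rule of `render_prior_context` as an iterator chain. Walking
// backwards from `idx`, keep only user/assistant turns, take `max_msgs`,
// then restore chronological order before rendering each turn as
// "[marker]\ntext" separated by blank lines.
#[allow(dead_code)]
mod sketch_prior_context {
    pub enum Role {
        System,
        User,
        Assistant,
    }

    pub struct Msg {
        pub role: Role,
        pub text: String,
    }

    pub fn prior_context(entries: &[Msg], idx: usize, max_msgs: usize) -> String {
        // Newest-first walk, filtered to user/assistant, capped at max_msgs.
        let picked: Vec<&Msg> = entries[..idx]
            .iter()
            .rev()
            .filter(|m| matches!(m.role, Role::User | Role::Assistant))
            .take(max_msgs)
            .collect();
        // Back to chronological order, one block per turn.
        let rendered: Vec<String> = picked
            .iter()
            .rev()
            .map(|m| {
                let marker = if matches!(m.role, Role::User) { "[user]" } else { "[assistant]" };
                format!("{}\n{}", marker, m.text.trim())
            })
            .collect();
        rendered.join("\n\n").trim_end().to_string()
    }
}
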
impl ContextState {
    /// Assemble the prompt in wire form: a token stream in which each
    /// image appears in the pre-expanded form stored on its leaf
    /// (`<|vision_start|>`, N × `<|image_pad|>`, `<|vision_end|>`), plus
    /// the list of images to send as multi_modal_data, plus the half-open
    /// (start, end) token positions of each assistant message branch
    /// emitted (used by the scoring path as `score_ranges`).
    ///
    /// `conv_range` selects a prefix (or any sub-range) of conversation
    /// entries to include: the agent path passes `0..conversation().len()`;
    /// scoring / candidate generation pass a prefix up to the entry of
    /// interest.
    ///
    /// `skip` is a predicate applied to identity and conversation entries;
    /// returning true drops the node from the prompt. The agent path passes
    /// `|_| false`; memory-ablation scoring passes e.g. `is_memory_node` or
    /// `|n| memory_key(n) == Some(key)`.
    pub fn wire_prompt<F>(
        &self,
        conv_range: std::ops::Range<usize>,
        mut skip: F,
    ) -> (Vec<u32>, Vec<WireImage>, Vec<(usize, usize)>)
    where F: FnMut(&AstNode) -> bool,
    {
        let mut tokens = Vec::new();
        let mut images = Vec::new();
        let mut assistant_ranges = Vec::new();

        for node in self.system() {
            wire_into(node, &mut tokens, &mut images);
        }
        for node in self.identity() {
            if skip(node) { continue; }
            wire_into(node, &mut tokens, &mut images);
        }
        for node in self.journal() {
            wire_into(node, &mut tokens, &mut images);
        }
        for node in &self.conversation()[conv_range] {
            if skip(node) { continue; }
            let start = tokens.len();
            let is_asst = matches!(node, AstNode::Branch { role: Role::Assistant, .. });
            wire_into(node, &mut tokens, &mut images);
            if is_asst {
                assistant_ranges.push((start, tokens.len()));
            }
        }
        (tokens, images, assistant_ranges)
    }

    /// Build the wire stream as interleaved `WireChunk`s for the gRPC
    /// session path. Unlike `wire_prompt`, this preserves the order
    /// of text runs vs image blocks so the caller can drive the
    /// append flow (AppendImage for each Image, Generate append for
    /// contiguous text runs).
    ///
    /// `conv_range` and `skip` mirror `wire_prompt`: they select a
    /// conversation slice and drop identity / conversation nodes by
    /// predicate.
    pub fn wire_chunks<F>(
        &self,
        conv_range: std::ops::Range<usize>,
        mut skip: F,
    ) -> Vec<WireChunk>
    where F: FnMut(&AstNode) -> bool,
    {
        let mut out: Vec<WireChunk> = Vec::new();
        let mut buf: Vec<u32> = Vec::new();

        fn flush(buf: &mut Vec<u32>, out: &mut Vec<WireChunk>) {
            if !buf.is_empty() {
                out.push(WireChunk::Tokens(std::mem::take(buf)));
            }
        }

        fn visit(node: &AstNode, buf: &mut Vec<u32>, out: &mut Vec<WireChunk>) {
            match node {
                AstNode::Leaf(leaf) => match leaf.body() {
                    NodeBody::Image { bytes, mime, token_count, .. } => {
                        flush(buf, out);
                        // Bookends (VISION_START + VISION_END) add 2
                        // to the expanded length; token_count is the
                        // IMAGE_PAD run. 0 means the count is still
                        // unknown (no AppendImage yet); don't claim
                        // a length the server will disagree with.
                        let expanded = if *token_count == 0 {
                            0
                        } else {
                            *token_count + 2
                        };
                        out.push(WireChunk::Image {
                            bytes: bytes.clone(),
                            mime: mime.clone(),
                            known_expanded_len: expanded,
                        });
                    }
                    _ => buf.extend_from_slice(leaf.token_ids()),
                },
                AstNode::Branch { role, children, .. } => {
                    buf.push(tokenizer::IM_START);
                    buf.extend(tokenizer::encode(&format!("{}\n", role.as_str())));
                    for c in children {
                        visit(c, buf, out);
                    }
                    buf.push(tokenizer::IM_END);
                    buf.extend(tokenizer::encode("\n"));
                }
            }
        }

        for node in self.system() { visit(node, &mut buf, &mut out); }
        for node in self.identity() {
            if skip(node) { continue; }
            visit(node, &mut buf, &mut out);
        }
        for node in self.journal() { visit(node, &mut buf, &mut out); }
        for node in &self.conversation()[conv_range] {
            if skip(node) { continue; }
            visit(node, &mut buf, &mut out);
        }
        flush(&mut buf, &mut out);
        out
    }
}

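// Sketch of the chunking rule in `wire_chunks`, on a toy tree with made-up
// token IDs (the `IM_START`/`IM_END` values and `role_tokens` below are
// placeholders, not the Qwen tokenizer's, and the newline tokens are left
// out). Text and framing tokens accumulate in a buffer; an image flushes the
// buffer and becomes its own chunk. Note that text runs are not split at
// message boundaries: one turn's IM_END and the next turn's IM_START can
// share a `Tokens` chunk.
#[allow(dead_code)]
mod sketch_wire_chunks {
    // Placeholder special-token IDs; the real ones come from the tokenizer.
    pub const IM_START: u32 = 1;
    pub const IM_END: u32 = 2;

    pub enum Node {
        Text(Vec<u32>),
        Image { pads: u32 },
        Branch { role_tokens: Vec<u32>, children: Vec<Node> },
    }

    #[derive(Debug, PartialEq)]
    pub enum Chunk {
        Tokens(Vec<u32>),
        Image { known_expanded_len: u32 },
    }

    fn flush(buf: &mut Vec<u32>, out: &mut Vec<Chunk>) {
        if !buf.is_empty() {
            out.push(Chunk::Tokens(std::mem::take(buf)));
        }
    }

    fn visit(node: &Node, buf: &mut Vec<u32>, out: &mut Vec<Chunk>) {
        match node {
            Node::Text(ids) => buf.extend_from_slice(ids),
            Node::Image { pads } => {
                // An image ends the current text run and stands alone.
                flush(buf, out);
                // pads + 2 bookends, or 0 while the pad count is unknown.
                let len = if *pads == 0 { 0 } else { *pads + 2 };
                out.push(Chunk::Image { known_expanded_len: len });
            }
            Node::Branch { role_tokens, children } => {
                // Framing tokens are buffered like ordinary text.
                buf.push(IM_START);
                buf.extend_from_slice(role_tokens);
                for c in children {
                    visit(c, buf, out);
                }
                buf.push(IM_END);
            }
        }
    }

    pub fn chunks(nodes: &[Node]) -> Vec<Chunk> {
        let mut out = Vec::new();
        let mut buf = Vec::new();
        for n in nodes {
            visit(n, &mut buf, &mut out);
        }
        flush(&mut buf, &mut out);
        out
    }
}
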
impl ContextState {
    fn section_mut(&mut self, section: Section) -> &mut Vec<AstNode> {
        match section {
            Section::System => &mut self.system,
            Section::Identity => &mut self.identity,
            Section::Journal => &mut self.journal,
            Section::Conversation => &mut self.conversation,
        }
    }

    /// Push `node` onto `section`, first appending it to the conversation
    /// log if one is attached. A log write failure is reported via
    /// `dbglog!` but does not block the push.
    pub fn push_log(&mut self, section: Section, node: AstNode) {
        if let Some(ref log) = self.conversation_log {
            if let Err(e) = log.append_node(&node) {
                dbglog!("warning: log: {:#}", e);
            }
        }
        self.section_mut(section).push(node);
    }

    /// Push without logging.
    pub fn push_no_log(&mut self, section: Section, node: AstNode) {
        self.section_mut(section).push(node);
    }

    /// Replace the body of a leaf at `index` in `section`.
    /// Re-tokenizes so the leaf's cached `token_ids` match the new body.
    pub fn set_message(&mut self, section: Section, index: usize, body: NodeBody) {
        let nodes = self.section_mut(section);
        let node = &mut nodes[index];
        match node {
            AstNode::Leaf(leaf) => {
                let token_ids = body.compute_token_ids();
                leaf.body = body;
                leaf.token_ids = token_ids;
            }
            AstNode::Branch { .. } => panic!("set_message on branch node"),
        }
    }

    /// Set the memory score on a Memory leaf at `index` in `section`.
    pub fn set_score(&mut self, section: Section, index: usize, score: Option<f64>) {
        let node = &mut self.section_mut(section)[index];
        match node {
            AstNode::Leaf(leaf) => match &mut leaf.body {
                NodeBody::Memory { score: s, .. } => *s = score,
                _ => panic!("set_score on non-memory node"),
            },
            _ => panic!("set_score on branch node"),
        }
    }

    pub fn del(&mut self, section: Section, index: usize) -> AstNode {
        self.section_mut(section).remove(index)
    }

    pub fn clear(&mut self, section: Section) {
        self.section_mut(section).clear();
    }

    /// Total tokens across every section that gets serialized into the prompt.
    /// Cheap sum over cached `node.tokens()`; call this before assembling to
    /// decide whether to trim.
    pub fn total_tokens(&self) -> usize {
        self.system().iter().map(|n| n.tokens()).sum::<usize>()
            + self.identity().iter().map(|n| n.tokens()).sum::<usize>()
            + self.journal().iter().map(|n| n.tokens()).sum::<usize>()
            + self.conversation().iter().map(|n| n.tokens()).sum::<usize>()
    }

    /// Dedup and trim conversation entries to fit within the context budget.
    ///
    /// Phase 1: Drop duplicate memories (keep last) and DMN entries.
    /// Phase 2: While over budget, drop the lowest-scored memory if memories
    /// make up more than 50% of conversation tokens and at least one memory
    /// is scored; otherwise drop the oldest non-memory entry (stop if none
    /// remain).
    /// Phase 3: Snap to user message boundary at start.
    pub fn trim_conversation(&mut self) {
        let max_tokens = context_budget_tokens();
        let fixed = self.system.iter().map(|n| n.tokens()).sum::<usize>()
            + self.identity.iter().map(|n| n.tokens()).sum::<usize>()
            + self.journal.iter().map(|n| n.tokens()).sum::<usize>();

        // Phase 1: dedup memories by key (keep last), drop DMN
        let mut seen_keys: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
        let mut drop = std::collections::HashSet::new();

        for (i, node) in self.conversation.iter().enumerate() {
            if let AstNode::Leaf(leaf) = node {
                match leaf.body() {
                    NodeBody::Dmn(_) => { drop.insert(i); }
                    NodeBody::Memory { key, .. } => {
                        if let Some(prev) = seen_keys.insert(key.clone(), i) {
                            drop.insert(prev);
                        }
                    }
                    _ => {}
                }
            }
        }

        if !drop.is_empty() {
            let mut i = 0;
            self.conversation.retain(|_| { let keep = !drop.contains(&i); i += 1; keep });
        }

        // Phase 2: while over budget, evict
        loop {
            let total: usize = self.conversation.iter().map(|n| n.tokens()).sum();
            if fixed + total <= max_tokens { break; }
            let mt: usize = self.conversation.iter()
                .filter(|n| matches!(n, AstNode::Leaf(l) if matches!(l.body(), NodeBody::Memory { .. })))
                .map(|n| n.tokens()).sum();
            let ct = total - mt;

            if mt > ct {
                // Memories > 50%: drop lowest-scored
                if let Some(i) = self.lowest_scored_memory() {
                    self.conversation.remove(i);
                    continue;
                }
            }
            // Drop oldest non-memory entry
            if let Some(i) = self.conversation.iter().position(|n|
                !matches!(n, AstNode::Leaf(l) if matches!(l.body(), NodeBody::Memory { .. })))
            {
                self.conversation.remove(i);
            } else {
                break;
            }
        }

        // Phase 3: snap to user message boundary
        while let Some(first) = self.conversation.first() {
            if matches!(first, AstNode::Branch { role: Role::User, .. }) { break; }
            self.conversation.remove(0);
        }
    }

    fn lowest_scored_memory(&self) -> Option<usize> {
        self.conversation.iter().enumerate()
            .filter_map(|(i, n)| {
                if let AstNode::Leaf(l) = n {
                    if let NodeBody::Memory { score: Some(s), .. } = l.body() {
                        return Some((i, *s));
                    }
                }
                None
            })
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
    }

    /// Push a child node into a branch at `index` in `section`.
    pub fn push_child(&mut self, section: Section, index: usize, child: AstNode) {
        let node = &mut self.section_mut(section)[index];
        match node {
            AstNode::Branch { children, .. } => children.push(child),
            AstNode::Leaf(_) => panic!("push_child on leaf node"),
        }
    }

    /// Number of nodes in a section.
    pub fn len(&self, section: Section) -> usize {
        match section {
            Section::System => self.system.len(),
            Section::Identity => self.identity.len(),
            Section::Journal => self.journal.len(),
            Section::Conversation => self.conversation.len(),
        }
    }
}

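// Sketch of the phase 1 and phase 2 policy of `trim_conversation` on a flat
// list of toy entries (hypothetical `Entry` type; token counts are plain
// numbers instead of cached tokenizer output). DMN dropping and the phase 3
// snap to a user boundary are left out. Memories are deduped by key, keeping
// the last copy. Then, while over budget, the lowest-scored memory goes if
// memories outweigh everything else and one is scored; otherwise the oldest
// non-memory entry goes.
#[allow(dead_code)]
mod sketch_trim {
    use std::cmp::Ordering;
    use std::collections::HashMap;

    #[derive(Debug, Clone, PartialEq)]
    pub enum Entry {
        Msg { tokens: usize },
        Memory { key: &'static str, tokens: usize, score: Option<f64> },
    }

    impl Entry {
        fn tokens(&self) -> usize {
            match self {
                Entry::Msg { tokens } | Entry::Memory { tokens, .. } => *tokens,
            }
        }

        fn is_memory(&self) -> bool {
            matches!(self, Entry::Memory { .. })
        }
    }

    pub fn trim(entries: &mut Vec<Entry>, budget: usize) {
        // Phase 1: keep only the last memory for each key.
        let mut last: HashMap<&'static str, usize> = HashMap::new();
        for (i, e) in entries.iter().enumerate() {
            if let Entry::Memory { key, .. } = e {
                last.insert(*key, i);
            }
        }
        let mut idx = 0;
        entries.retain(|e| {
            let keep = match e {
                Entry::Memory { key, .. } => last[key] == idx,
                Entry::Msg { .. } => true,
            };
            idx += 1;
            keep
        });

        // Phase 2: evict until the total fits the budget.
        loop {
            let total: usize = entries.iter().map(Entry::tokens).sum();
            if total <= budget {
                break;
            }
            let mem: usize = entries.iter().filter(|e| e.is_memory()).map(Entry::tokens).sum();
            if mem > total - mem {
                // Memories dominate: drop the lowest-scored one, if any is scored.
                let lowest = entries
                    .iter()
                    .enumerate()
                    .filter_map(|(i, e)| match e {
                        Entry::Memory { score: Some(s), .. } => Some((i, *s)),
                        _ => None,
                    })
                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
                if let Some((i, _)) = lowest {
                    entries.remove(i);
                    continue;
                }
            }
            // Otherwise drop the oldest non-memory entry, or stop if none is left.
            match entries.iter().position(|e| !e.is_memory()) {
                Some(i) => {
                    entries.remove(i);
                }
                None => break,
            }
        }
    }
}
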
// ---------------------------------------------------------------------------
// Qwen3-VL image token count
//
// Port of Qwen2VLImageProcessor.smart_resize + image_token_count. We need the
// exact same answer that vLLM's Qwen3VL processor will produce, because the
// token stream in our context must match what vLLM expands `<|image_pad|>`
// to at request time. Constants come from Qwen3.5-27B's preprocessor_config.
// ---------------------------------------------------------------------------

// Test-only client-side estimate of image token expansion. Production
// callers obtain the authoritative count from the server via
// CountImageTokens; these constants and helpers stay around only to
// keep the context-shape unit tests self-contained.
#[cfg(test)]
const QWEN3_PATCH_SIZE: u32 = 16;
#[cfg(test)]
const QWEN3_MERGE_SIZE: u32 = 2;
#[cfg(test)]
const QWEN3_MIN_PIXELS: u64 = 65_536;
#[cfg(test)]
const QWEN3_MAX_PIXELS: u64 = 16_777_216;

#[cfg(test)]
fn smart_resize(h: u32, w: u32, factor: u32, min_pixels: u64, max_pixels: u64) -> (u32, u32) {
    let max_s = h.max(w) as f64;
    let min_s = h.min(w) as f64;
    assert!(max_s / min_s <= 200.0, "aspect ratio too extreme: {}x{}", h, w);

    let fh = h as f64;
    let fw = w as f64;
    let ff = factor as f64;

    let h_bar = ((fh / ff).round() as u32) * factor;
    let w_bar = ((fw / ff).round() as u32) * factor;
    let total = (h_bar as u64) * (w_bar as u64);

    if total > max_pixels {
        let beta = ((fh * fw) / max_pixels as f64).sqrt();
        let hf = ((fh / beta / ff).floor() as u32) * factor;
        let wf = ((fw / beta / ff).floor() as u32) * factor;
        (hf.max(factor), wf.max(factor))
    } else if total < min_pixels {
        let beta = (min_pixels as f64 / (fh * fw)).sqrt();
        let hc = ((fh * beta / ff).ceil() as u32) * factor;
        let wc = ((fw * beta / ff).ceil() as u32) * factor;
        (hc, wc)
    } else {
        (h_bar, w_bar)
    }
}

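// Note on the port (sketch, not wired in): Python's built-in `round` rounds
// exact halves to the even neighbour, while Rust's `f64::round` rounds them
// away from zero. `smart_resize` above uses the latter for `h_bar`/`w_bar`,
// so a side that is an odd multiple of factor/2 rounds differently. For
// example, 80 px with factor 32 is 2.5 patches: 64 px under Python's rule,
// 96 px here. If exact parity with a Python processor that uses `round` is
// needed, a helper like this hypothetical `round_half_even` could replace
// `.round()`. Newer Rust toolchains also provide `f64::round_ties_even`.
#[allow(dead_code)]
mod sketch_round {
    /// Round to the nearest integer, ties to even (Python 3 `round` semantics).
    pub fn round_half_even(x: f64) -> f64 {
        let r = x.round();
        // `round` sends exact halves away from zero; pull odd results back.
        if (x - x.trunc()).abs() == 0.5 && r % 2.0 != 0.0 {
            r - x.signum()
        } else {
            r
        }
    }
}
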
/// Test-only: client-side estimate of how many `<|image_pad|>` tokens
/// vLLM will emit for an image of the given dimensions. Production
/// callers use `salience::count_image_tokens` (server-authoritative).
#[cfg(test)]
fn qwen3_image_token_count(orig_h: u32, orig_w: u32) -> u32 {
    let factor = QWEN3_PATCH_SIZE * QWEN3_MERGE_SIZE;
    let (rh, rw) = smart_resize(orig_h, orig_w, factor, QWEN3_MIN_PIXELS, QWEN3_MAX_PIXELS);
    (rh / QWEN3_PATCH_SIZE) * (rw / QWEN3_PATCH_SIZE) / (QWEN3_MERGE_SIZE * QWEN3_MERGE_SIZE)
}

/// Context window (in tokens) of the configured default backend; 128,000
/// if the backend or its `context_window` setting is absent.
pub fn context_window() -> usize {
    let app = crate::config::app();
    app.backends.get(&app.default_backend)
        .and_then(|b| b.context_window)
        .unwrap_or(128_000)
}

/// Token budget that `trim_conversation` trims against: 80% of
/// `context_window()`.
pub fn context_budget_tokens() -> usize {
    context_window() * 80 / 100
}

/// Heuristic match on the error text for context-length / prompt-too-large
/// failures.
pub fn is_context_overflow(err: &anyhow::Error) -> bool {
    let msg = err.to_string().to_lowercase();
    msg.contains("context length")
        || msg.contains("token limit")
        || msg.contains("too many tokens")
        || msg.contains("maximum context")
        || msg.contains("prompt is too long")
        || msg.contains("request too large")
        || msg.contains("input validation error")
        || msg.contains("content length limit")
        || (msg.contains("400") && msg.contains("tokens"))
}

/// True if the error text contains "model stream error".
pub fn is_stream_error(err: &anyhow::Error) -> bool {
    err.to_string().contains("model stream error")
}

#[cfg(test)]
mod tests {
    use super::*;

    // -- Helpers for inspecting parse results ----------------------------------

    fn bodies(nodes: &[AstNode]) -> Vec<&NodeBody> {
        nodes.iter().filter_map(|c| c.leaf()).map(|l| l.body()).collect()
    }

    fn assert_content(body: &NodeBody, expected: &str) {
        match body {
            NodeBody::Content(t) => assert_eq!(t, expected),
            other => panic!("expected Content, got {:?}", other),
        }
    }

    fn assert_thinking(body: &NodeBody, expected: &str) {
        match body {
            NodeBody::Thinking(t) => assert_eq!(t, expected),
            other => panic!("expected Thinking, got {:?}", other),
        }
    }

    fn assert_tool_call<'a>(body: &'a NodeBody, expected_name: &str) -> &'a str {
        match body {
            NodeBody::ToolCall { name, arguments } => {
                assert_eq!(name, expected_name);
                arguments
            }
            other => panic!("expected ToolCall, got {:?}", other),
        }
    }

    // -- XML parsing tests ----------------------------------------------------

    #[test]
    fn test_tool_call_xml_parse_clean() {
        let body = "<function=bash>\n<parameter=command>poc-memory used core-personality</parameter>\n</function>";
        let (name, args) = parse_tool_call_body(body).unwrap();
        assert_eq!(name, "bash");
        let args: serde_json::Value = serde_json::from_str(&args).unwrap();
        assert_eq!(args["command"], "poc-memory used core-personality");
    }

    #[test]
    fn test_tool_call_xml_parse_streamed_whitespace() {
        // Streaming tokenization can insert whitespace in opening tags,
        // but close tags are always emitted verbatim.
        let body = "<\nfunction\n=\nbash\n>\n<\nparameter\n=\ncommand\n>pwd</parameter>\n</function>";
        let (name, args) = parse_tool_call_body(body).unwrap();
        assert_eq!(name, "bash");
        let args: serde_json::Value = serde_json::from_str(&args).unwrap();
        assert_eq!(args["command"], "pwd");
    }

    #[test]
    fn test_tool_call_json_parse() {
        let body = r#"{"name": "bash", "arguments": {"command": "ls"}}"#;
        let (name, args) = parse_tool_call_body(body).unwrap();
        assert_eq!(name, "bash");
        let args: serde_json::Value = serde_json::from_str(&args).unwrap();
        assert_eq!(args["command"], "ls");
    }

    #[test]
    fn test_tool_call_preserves_code_with_angle_brackets() {
        let body = "<function=edit>\n<parameter=code>if x < y {\n std::mem::swap(&mut a, &mut b);\n}</parameter>\n</function>";
        let (name, args) = parse_tool_call_body(body).unwrap();
        assert_eq!(name, "edit");
        let args: serde_json::Value = serde_json::from_str(&args).unwrap();
        assert_eq!(args["code"], "if x < y {\n std::mem::swap(&mut a, &mut b);\n}");
    }

    // -- ResponseParser tests -------------------------------------------------

    /// Set up a ContextState with an assistant branch, run the parser over
    /// `chunks`, and return the context (whose branch now holds the parsed
    /// children) along with any tool calls the parser emitted.
    fn parse_into_ctx(chunks: &[&str]) -> (ContextState, Vec<PendingToolCall>) {
        let mut ctx = ContextState::new();
        ctx.push_no_log(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
        let mut p = ResponseParser::new(0, false);
        let mut calls = Vec::new();
        for chunk in chunks {
            // Feed each chunk as if it were a single decoded token.
            calls.extend(p.feed_token(chunk, &mut ctx));
        }
        p.finish(&mut ctx);
        (ctx, calls)
    }

    fn assistant_children(ctx: &ContextState) -> &[AstNode] {
        ctx.conversation()[0].children()
    }

    #[test]
    fn test_parser_plain_text() {
        let (ctx, _) = parse_into_ctx(&["hello world"]);
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 1);
        assert_content(b[0], "hello world");
    }

    #[test]
    fn test_parser_thinking_then_content() {
        let (ctx, _) = parse_into_ctx(&["<think>reasoning</think>answer"]);
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 2);
        assert_thinking(b[0], "reasoning");
        assert_content(b[1], "answer");
    }

    #[test]
    fn test_parser_tool_call() {
        let (ctx, calls) = parse_into_ctx(&[
            "<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>"
        ]);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "bash");
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 1);
        let args = assert_tool_call(b[0], "bash");
        let args: serde_json::Value = serde_json::from_str(args).unwrap();
        assert_eq!(args["command"], "ls");
    }

    #[test]
    fn test_parser_content_then_tool_call_then_content() {
        let (ctx, _) = parse_into_ctx(&[
            "before",
            "<tool_call>\n<function=bash>\n<parameter=command>pwd</parameter>\n</function>\n</tool_call>",
            "after",
        ]);
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 3);
        assert_content(b[0], "before");
        assert_tool_call(b[1], "bash");
        assert_content(b[2], "after");
    }

    #[test]
    fn test_parser_incremental_feed() {
        let text = "<think>thought</think>response";
        let mut ctx = ContextState::new();
        ctx.push_no_log(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
        let mut p = ResponseParser::new(0, false);
        for ch in text.chars() {
            p.feed_token(&ch.to_string(), &mut ctx);
        }
        p.finish(&mut ctx);
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 2);
        assert_thinking(b[0], "thought");
        assert_content(b[1], "response");
    }

    #[test]
    fn test_parser_incremental_tool_call() {
        let text = "text<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>more";
        let mut ctx = ContextState::new();
        ctx.push_no_log(Section::Conversation, AstNode::branch(Role::Assistant, vec![]));
        let mut p = ResponseParser::new(0, false);
        let mut tool_calls = 0;
        for ch in text.chars() {
            tool_calls += p.feed_token(&ch.to_string(), &mut ctx).len();
        }
        p.finish(&mut ctx);
        assert_eq!(tool_calls, 1);
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 3);
        assert_content(b[0], "text");
        assert_tool_call(b[1], "bash");
        assert_content(b[2], "more");
    }

    #[test]
    fn test_parser_thinking_tool_call_content() {
        let (ctx, _) = parse_into_ctx(&[
            "<think>let me think</think>",
            "<tool_call>\n<function=read>\n<parameter=path>/etc/hosts</parameter>\n</function>\n</tool_call>",
            "here's what I found",
        ]);
        let b = bodies(assistant_children(&ctx));
        assert_eq!(b.len(), 3);
        assert_thinking(b[0], "let me think");
        assert_tool_call(b[1], "read");
        assert_content(b[2], "here's what I found");
    }

    // -- Round-trip rendering tests -------------------------------------------

    #[test]
    fn test_render_system_msg() {
        let node = AstNode::system_msg("you are helpful");
        assert_eq!(node.render(), "<|im_start|>system\nyou are helpful<|im_end|>\n");
    }

    #[test]
    fn test_render_user_msg() {
        let node = AstNode::user_msg("hello");
        assert_eq!(node.render(), "<|im_start|>user\nhello<|im_end|>\n");
    }

    #[test]
    fn test_render_assistant_with_thinking_and_content() {
        let node = AstNode::branch(Role::Assistant, vec![
            AstNode::thinking("hmm"),
            AstNode::content("answer"),
        ]);
        // Thinking renders wrapped in <think>...</think> so the model sees
        // previous turns' reasoning (Qwen 3.6 style: CoT stays in the
        // conversation across turns).
        assert_eq!(node.render(), "<|im_start|>assistant\n<think>\nhmm\n</think>\nanswer<|im_end|>\n");
    }

    #[test]
    fn test_render_tool_result() {
        let node = AstNode::tool_result("output here");
        assert_eq!(node.render(), "<|im_start|>user\n<tool_response>\noutput here\n</tool_response><|im_end|>\n");
    }

    #[test]
    fn test_render_memory() {
        let node = AstNode::memory("identity", "I am Proof of Concept");
        assert_eq!(node.render(), "<|im_start|>memory\nI am Proof of Concept<|im_end|>\n");
    }

    #[test]
    fn test_render_dmn() {
        let node = AstNode::dmn("subconscious prompt");
        assert_eq!(node.render(), "<|im_start|>dmn\nsubconscious prompt<|im_end|>\n");
    }

    #[test]
    fn test_render_tool_call() {
        let node = AstNode::tool_call("bash", r#"{"command":"ls"}"#);
        let rendered = node.render();
        assert!(rendered.contains("<tool_call>"));
        assert!(rendered.contains("<function=bash>"));
        assert!(rendered.contains("<parameter=command>"));
        assert!(rendered.contains("ls"));
        assert!(rendered.contains("</tool_call>"));
    }

// -- Tokenizer round-trip tests -------------------------------------------
|
||
// These require the tokenizer file; skipped if not present.
|
||
|
||
fn init_tokenizer() -> bool {
|
||
let path = format!("{}/.consciousness/tokenizer-qwen35.json",
|
||
std::env::var("HOME").unwrap_or_default());
|
||
if std::path::Path::new(&path).exists() {
|
||
tokenizer::init(&path);
|
||
true
|
||
} else {
|
||
false
|
||
}
|
||
}
|
||
|
||
fn assert_token_invariants(node: &AstNode) {
|
||
assert_eq!(node.tokens(), node.token_ids().len(),
|
||
"tokens() != token_ids().len()");
|
||
}
|
||
|
||
#[test]
|
||
fn test_tokenize_roundtrip_leaf_types() {
|
||
if !init_tokenizer() { return; }
|
||
|
||
assert_token_invariants(&AstNode::system_msg("you are a helpful assistant"));
|
||
assert_token_invariants(&AstNode::user_msg("what is 2+2?"));
|
||
assert_token_invariants(&AstNode::tool_result("4"));
|
||
assert_token_invariants(&AstNode::memory("identity", "I am Proof of Concept"));
|
||
assert_token_invariants(&AstNode::dmn("check the memory store"));
|
||
assert_token_invariants(&AstNode::tool_call("bash", r#"{"command":"ls -la"}"#));
|
||
}

    #[test]
    fn test_tokenize_roundtrip_assistant_branch() {
        if !init_tokenizer() { return; }

        let node = AstNode::branch(Role::Assistant, vec![
            AstNode::content("here's what I found:\n"),
            AstNode::tool_call("bash", r#"{"command":"pwd"}"#),
            AstNode::content("\nthat's the current directory"),
        ]);
        assert_token_invariants(&node);
    }

    #[test]
    fn test_tokenize_invisible_nodes_are_zero() {
        if !init_tokenizer() { return; }

        assert_eq!(AstNode::log("debug info").tokens(), 0);
    }

    #[test]
    fn test_tokenize_thinking_matches_rendered_tags() {
        if !init_tokenizer() { return; }

        // Thinking is now prompt-visible (wrapped in <think>...</think>);
        // token count must match the rendered wrapping.
        let node = AstNode::thinking("deep thoughts");
        assert_eq!(node.tokens(), tokenizer::encode(&node.render()).len());
    }

    #[test]
    fn test_tokenize_decode_roundtrip() {
        if !init_tokenizer() { return; }

        // Content without special tokens round-trips through decode
        let text = "hello world, this is a test";
        let ids = tokenizer::encode(text);
        let decoded = tokenizer::decode(&ids);
        assert_eq!(decoded, text);
    }

    #[test]
    fn test_tokenize_context_state_matches_concatenation() {
        if !init_tokenizer() { return; }

        let mut ctx = ContextState::new();
        ctx.push_no_log(Section::System, AstNode::system_msg("you are helpful"));
        ctx.push_no_log(Section::Identity, AstNode::memory("name", "Proof of Concept"));
        ctx.push_no_log(Section::Conversation, AstNode::user_msg("hi"));

        assert_eq!(ctx.tokens(), ctx.token_ids().len());
    }

    #[test]
    fn test_parser_roundtrip_through_tokenizer() {
        if !init_tokenizer() { return; }

        let (ctx, _) = parse_into_ctx(&[
            "I'll check that for you",
            "<tool_call>\n<function=bash>\n<parameter=command>ls</parameter>\n</function>\n</tool_call>",
        ]);
        let node = &ctx.conversation()[0];
        assert_token_invariants(node);
        assert!(node.tokens() > 0);
    }

    // -- Timestamp deserialization tests ------------------------------------------

    #[test]
    fn test_timestamp_null_rejected() {
        // Missing/null timestamps used to be accepted via a lenient
        // deserialize fallback. Post-migration the schema is strict.
        let json = r#"{"Leaf":{"body":{"Content":"hello"},"timestamp":null}}"#;
        assert!(serde_json::from_str::<AstNode>(json).is_err());
    }

    #[test]
    fn test_timestamp_missing_rejected() {
        let json = r#"{"Leaf":{"body":{"Content":"hello"}}}"#;
        assert!(serde_json::from_str::<AstNode>(json).is_err());
    }

    #[test]
    fn test_branch_timestamp_missing_rejected() {
        let json = r#"{"Branch":{"role":"User","children":[]}}"#;
        assert!(serde_json::from_str::<AstNode>(json).is_err());
    }

    // -- Image leaf tests ---------------------------------------------------------

    #[test]
    fn test_smart_resize_within_bounds() {
        // Typical case: 1024x768 → rounded to multiples of 32, under max.
        let (h, w) = smart_resize(768, 1024, 32, 65_536, 16_777_216);
        assert_eq!(h, 768);
        assert_eq!(w, 1024);
    }

    #[test]
    fn test_smart_resize_upscales_tiny() {
        // 32x32 = 1024 pixels, below min_pixels=65536. Should scale up.
        let (h, w) = smart_resize(32, 32, 32, 65_536, 16_777_216);
        assert!((h as u64) * (w as u64) >= 65_536,
            "resized {}x{} is under min_pixels", h, w);
        assert_eq!(h % 32, 0);
        assert_eq!(w % 32, 0);
    }

    #[test]
    fn test_smart_resize_downscales_huge() {
        // 8000x6000 = 48M pixels, above max_pixels=16M. Should scale down.
        let (h, w) = smart_resize(8000, 6000, 32, 65_536, 16_777_216);
        assert!((h as u64) * (w as u64) <= 16_777_216,
            "resized {}x{} exceeds max_pixels", h, w);
        assert_eq!(h % 32, 0);
        assert_eq!(w % 32, 0);
    }
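
    // The three smart_resize tests above pin down a contract: snap each side
    // to a multiple of `factor`, then rescale both sides by one common ratio
    // if the area falls outside [min_pixels, max_pixels]. As a sketch,
    // assuming the crate follows the Qwen2-VL image-processor recipe (not
    // confirmed by these tests), a hypothetical reference version could look
    // like this. Note that f64::round rounds halves away from zero, unlike
    // Python's round(). The extreme-aspect-ratio check is omitted.
    fn reference_smart_resize(height: u32, width: u32, factor: u32,
                              min_pixels: u64, max_pixels: u64) -> (u32, u32) {
        let (h, w, f) = (height as f64, width as f64, factor as f64);
        let (min_px, max_px) = (min_pixels as f64, max_pixels as f64);
        // Snap each side to the nearest multiple of `factor`, never below one factor.
        let mut h_bar = ((h / f).round() * f).max(f);
        let mut w_bar = ((w / f).round() * f).max(f);
        if h_bar * w_bar > max_px {
            // Too large: shrink both sides by the same ratio, flooring to the grid.
            let beta = (h * w / max_px).sqrt();
            h_bar = ((h / beta / f).floor() * f).max(f);
            w_bar = ((w / beta / f).floor() * f).max(f);
        } else if h_bar * w_bar < min_px {
            // Too small: grow both sides by the same ratio, ceiling to the grid.
            let beta = (min_px / (h * w)).sqrt();
            h_bar = (h * beta / f).ceil() * f;
            w_bar = (w * beta / f).ceil() * f;
        }
        (h_bar as u32, w_bar as u32)
    }

    #[test]
    fn test_reference_smart_resize_sketch() {
        assert_eq!(reference_smart_resize(768, 1024, 32, 65_536, 16_777_216), (768, 1024));
        assert_eq!(reference_smart_resize(32, 32, 32, 65_536, 16_777_216), (256, 256));
        let (h, w) = reference_smart_resize(8000, 6000, 32, 65_536, 16_777_216);
        assert!((h as u64) * (w as u64) <= 16_777_216);
        assert_eq!((h % 32, w % 32), (0, 0));
    }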

    #[test]
    fn test_qwen3_token_count_matches_formula() {
        // 512x512 → resized to 512x512 (already multiple of 32, within bounds).
        // grid = 32x32, tokens = 32*32/4 = 256.
        assert_eq!(qwen3_image_token_count(512, 512), 256);
    }
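
    // Worked arithmetic behind the comment above, as a hypothetical sketch:
    // assuming 16-pixel vision patches merged 2x2 (consistent with the resize
    // factor of 32 and the 32x32 grid quoted for 512x512), the placeholder
    // count for an already-resized image is (h/16) * (w/16) / 4. This mirrors
    // the formula in the comment; it is not the crate's implementation.
    fn sketch_tokens_for_resized(height: u32, width: u32) -> u32 {
        const PATCH: u32 = 16; // assumed vision patch size in pixels
        const MERGE: u32 = 2;  // assumed spatial merge along each axis
        let (grid_h, grid_w) = (height / PATCH, width / PATCH);
        // Each MERGE x MERGE block of patches becomes one image_pad token.
        grid_h * grid_w / (MERGE * MERGE)
    }

    #[test]
    fn test_sketch_tokens_for_resized() {
        assert_eq!(sketch_tokens_for_resized(512, 512), 256);
        assert_eq!(sketch_tokens_for_resized(768, 1024), 768);
    }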

    #[test]
    fn test_image_render_and_token_ids() {
        let node = AstNode::image(vec![0u8, 1, 2, 3], "image/png", 512, 512, qwen3_image_token_count(512, 512));
        let leaf = node.leaf().unwrap();
        // 2 bookend tokens (vision_start, vision_end) + 256 image_pad tokens
        assert_eq!(leaf.token_ids().len(), 258);
        assert_eq!(leaf.token_ids()[0], tokenizer::VISION_START);
        assert_eq!(leaf.token_ids()[257], tokenizer::VISION_END);
        for pad in &leaf.token_ids()[1..257] {
            assert_eq!(*pad, tokenizer::IMAGE_PAD);
        }
        // Rendered text has the expected bookends.
        let rendered = leaf.body().render();
        assert!(rendered.starts_with("<|vision_start|>"));
        assert!(rendered.ends_with("<|vision_end|>"));
    }
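
    // The layout checked above is [VISION_START, IMAGE_PAD x N, VISION_END],
    // so an image leaf's token_ids().len() is always N + 2. A minimal sketch
    // of that expansion, using hypothetical placeholder ids (not the real
    // tokenizer::VISION_START / IMAGE_PAD / VISION_END values):
    const SKETCH_IMG_VISION_START: u32 = 900_001;
    const SKETCH_IMG_IMAGE_PAD: u32 = 900_002;
    const SKETCH_IMG_VISION_END: u32 = 900_003;

    fn sketch_image_token_ids(n_pads: usize) -> Vec<u32> {
        let mut ids = Vec::with_capacity(n_pads + 2);
        ids.push(SKETCH_IMG_VISION_START);
        // N placeholder positions, one per image token.
        ids.extend(std::iter::repeat(SKETCH_IMG_IMAGE_PAD).take(n_pads));
        ids.push(SKETCH_IMG_VISION_END);
        ids
    }

    #[test]
    fn test_sketch_image_token_ids() {
        let ids = sketch_image_token_ids(256);
        assert_eq!(ids.len(), 258);
        assert_eq!(ids[0], SKETCH_IMG_VISION_START);
        assert_eq!(ids[257], SKETCH_IMG_VISION_END);
        assert!(ids[1..257].iter().all(|&t| t == SKETCH_IMG_IMAGE_PAD));
    }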

    #[test]
    fn test_wire_prompt_preserves_expanded_image_pads() {
        let mut ctx = ContextState::new();
        ctx.push_no_log(Section::Conversation, AstNode::branch(Role::User, vec![
            AstNode::content("look:"),
            AstNode::image(vec![0xDE, 0xAD], "image/png", 512, 512, qwen3_image_token_count(512, 512)),
        ]));

        // The AST side and the wire side should both carry N image_pads plus
        // bookends: the server's session.tokens length must match what vLLM's
        // engine will actually process. Binary image bytes are shipped
        // separately in multi_modal_data via the WireImage list.
        let n_expected = qwen3_image_token_count(512, 512) as usize;

        let full = ctx.token_ids();
        let n_image_pads_full = full.iter()
            .filter(|&&t| t == tokenizer::IMAGE_PAD).count();
        assert_eq!(n_image_pads_full, n_expected);

        let (wire, images, _) = ctx.wire_prompt(0..ctx.conversation().len(), |_| false);
        let n_image_pads_wire = wire.iter()
            .filter(|&&t| t == tokenizer::IMAGE_PAD).count();
        assert_eq!(n_image_pads_wire, n_expected);

        assert_eq!(images.len(), 1);
        assert_eq!(images[0].bytes, vec![0xDE, 0xAD]);
        assert_eq!(images[0].mime, "image/png");

        // One pair of vision_start/vision_end bookends around the N pads.
        assert_eq!(wire.iter().filter(|&&t| t == tokenizer::VISION_START).count(), 1);
        assert_eq!(wire.iter().filter(|&&t| t == tokenizer::VISION_END).count(), 1);
    }
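
    // The test above relies on one walk producing two outputs: a flat token
    // list that keeps the pads inline (so lengths agree with what vLLM
    // processes) and a side list of raw image bytes for multi_modal_data.
    // A hypothetical, self-contained sketch of that split; SketchWirePart and
    // the ids below are placeholders, not the crate's types or token values.
    const SKETCH_WIRE_VISION_START: u32 = 910_001;
    const SKETCH_WIRE_IMAGE_PAD: u32 = 910_002;
    const SKETCH_WIRE_VISION_END: u32 = 910_003;

    enum SketchWirePart {
        Text(Vec<u32>),
        Image { bytes: Vec<u8>, n_pads: usize },
    }

    fn sketch_flatten_wire(parts: &[SketchWirePart]) -> (Vec<u32>, Vec<Vec<u8>>) {
        let mut tokens = Vec::new();
        let mut images = Vec::new();
        for part in parts {
            match part {
                SketchWirePart::Text(ids) => tokens.extend_from_slice(ids),
                SketchWirePart::Image { bytes, n_pads } => {
                    // Pads stay inline so the token count matches the engine's view...
                    tokens.push(SKETCH_WIRE_VISION_START);
                    tokens.extend(std::iter::repeat(SKETCH_WIRE_IMAGE_PAD).take(*n_pads));
                    tokens.push(SKETCH_WIRE_VISION_END);
                    // ...while the binary payload travels out of band, in wire order.
                    images.push(bytes.clone());
                }
            }
        }
        (tokens, images)
    }

    #[test]
    fn test_sketch_flatten_wire() {
        let parts = [
            SketchWirePart::Text(vec![1, 2, 3]),
            SketchWirePart::Image { bytes: vec![0xDE, 0xAD], n_pads: 256 },
        ];
        let (tokens, images) = sketch_flatten_wire(&parts);
        assert_eq!(tokens.iter().filter(|&&t| t == SKETCH_WIRE_IMAGE_PAD).count(), 256);
        assert_eq!(tokens.len(), 3 + 256 + 2);
        assert_eq!(images, vec![vec![0xDE, 0xAD]]);
    }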

    #[test]
    fn test_image_serde_roundtrip() {
        let node = AstNode::image(vec![0xDE, 0xAD, 0xBE, 0xEF], "image/png", 64, 64, qwen3_image_token_count(64, 64));
        let json = serde_json::to_string(&node).unwrap();
        // bytes must be base64-encoded in the JSON form
        assert!(json.contains("3q2+7w=="));
        let back: AstNode = serde_json::from_str(&json).unwrap();
        let leaf = back.leaf().unwrap();
        match leaf.body() {
            NodeBody::Image { bytes, mime, orig_height, orig_width, token_count } => {
                assert_eq!(bytes, &[0xDE, 0xAD, 0xBE, 0xEF]);
                assert_eq!(mime, "image/png");
                assert_eq!(*orig_height, 64);
                assert_eq!(*orig_width, 64);
                assert_eq!(*token_count, qwen3_image_token_count(64, 64));
            }
            other => panic!("expected Image, got {:?}", other),
        }
        // token_ids are recomputed on deserialization
        assert_eq!(leaf.token_ids().len(), leaf.tokens());
    }
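
    // The "3q2+7w==" literal above is the standard-alphabet, padded base64 of
    // the bytes DE AD BE EF. A dependency-free sketch of that encoding
    // (hypothetical helper for illustration only, not the crate's serializer),
    // handy for deriving expected strings when writing new serde tests:
    fn sketch_base64_standard(input: &[u8]) -> String {
        const ALPHABET: &[u8; 64] =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let mut out = String::with_capacity((input.len() + 2) / 3 * 4);
        for chunk in input.chunks(3) {
            // Pack up to three bytes into a 24-bit group (missing bytes are zero).
            let b = [chunk[0], *chunk.get(1).unwrap_or(&0), *chunk.get(2).unwrap_or(&0)];
            let group = (u32::from(b[0]) << 16) | (u32::from(b[1]) << 8) | u32::from(b[2]);
            // Emit one character per 6 bits; pad with '=' where input bytes were absent.
            for i in 0..4 {
                if i <= chunk.len() {
                    let idx = ((group >> (18 - 6 * i)) & 0x3F) as usize;
                    out.push(ALPHABET[idx] as char);
                } else {
                    out.push('=');
                }
            }
        }
        out
    }

    #[test]
    fn test_sketch_base64_standard() {
        assert_eq!(sketch_base64_standard(&[0xDE, 0xAD, 0xBE, 0xEF]), "3q2+7w==");
        assert_eq!(sketch_base64_standard(b"Man"), "TWFu");
    }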

    #[test]
    fn test_timestamp_present_accepted() {
        let json = r#"{"Leaf":{"body":{"Content":"hi"},"timestamp":"2026-04-16T12:00:00Z"}}"#;
        let node: AstNode = serde_json::from_str(json).unwrap();
        let leaf = node.leaf().unwrap();
        assert_eq!(leaf.timestamp().to_rfc3339(),
            "2026-04-16T12:00:00+00:00");
    }
}