diff --git a/src/agent/api/mod.rs b/src/agent/api/mod.rs index 7c06fa7..649d95c 100644 --- a/src/agent/api/mod.rs +++ b/src/agent/api/mod.rs @@ -78,18 +78,31 @@ impl ApiClient { prompt_tokens: &[u32], sampling: SamplingParams, priority: Option<i32>, + ) -> (mpsc::UnboundedReceiver<StreamToken>, AbortOnDrop) { + self.stream_completion_mm(prompt_tokens, &[], sampling, priority) + } + + pub(crate) fn stream_completion_mm( + &self, + prompt_tokens: &[u32], + images: &[super::context::WireImage], + sampling: SamplingParams, + priority: Option<i32>, ) -> (mpsc::UnboundedReceiver<StreamToken>, AbortOnDrop) { let (tx, rx) = mpsc::unbounded_channel(); let client = self.client.clone(); let api_key = self.api_key.clone(); let model = self.model.clone(); let prompt_tokens = prompt_tokens.to_vec(); + let images: Vec<(Vec<u8>, String)> = images.iter() + .map(|i| (i.bytes.clone(), i.mime.clone())) + .collect(); let base_url = self.base_url.clone(); let handle = tokio::spawn(async move { let result = stream_completions( &client, &base_url, &api_key, &model, - &prompt_tokens, &tx, sampling, priority, + &prompt_tokens, &images, &tx, sampling, priority, ).await; if let Err(e) = result { let _ = tx.send(StreamToken::Error(e.to_string())); @@ -110,6 +123,7 @@ async fn stream_completions( api_key: &str, model: &str, prompt_tokens: &[u32], + images: &[(Vec<u8>, String)], tx: &mpsc::UnboundedSender<StreamToken>, sampling: SamplingParams, priority: Option<i32>, @@ -126,6 +140,14 @@ async fn stream_completions( "skip_special_tokens": false, "stop_token_ids": [super::tokenizer::IM_END], }); + if !images.is_empty() { + use base64::Engine; + let b64 = base64::engine::general_purpose::STANDARD; + let uris: Vec<String> = images.iter() + .map(|(bytes, mime)| format!("data:{};base64,{}", mime, b64.encode(bytes))) + .collect(); + request["multi_modal_data"] = serde_json::json!({ "image": uris }); + } if let Some(p) = priority { request["priority"] = serde_json::json!(p); } diff --git a/src/agent/context.rs b/src/agent/context.rs index 57b2c7a..0082f06 100644 ---
a/src/agent/context.rs +++ b/src/agent/context.rs @@ -884,6 +884,58 @@ impl Ast for ContextState { } } +/// An image collected from the AST for a request body. The AST stores +/// the pre-expanded token form (N image_pads) for accurate budget +/// accounting; the wire form collapses each Image to a single +/// `<|image_pad|>` between vision bookends and ships the bytes +/// separately as multi_modal_data. +pub struct WireImage { + pub bytes: Vec<u8>, + pub mime: String, +} + +fn wire_into(node: &AstNode, tokens: &mut Vec<u32>, images: &mut Vec<WireImage>) { + match node { + AstNode::Leaf(leaf) => match leaf.body() { + NodeBody::Image { bytes, mime, .. } => { + tokens.push(tokenizer::VISION_START); + tokens.push(tokenizer::IMAGE_PAD); + tokens.push(tokenizer::VISION_END); + images.push(WireImage { + bytes: bytes.clone(), + mime: mime.clone(), + }); + } + _ => tokens.extend_from_slice(leaf.token_ids()), + }, + AstNode::Branch { role, children, .. } => { + tokens.push(tokenizer::IM_START); + tokens.extend(tokenizer::encode(&format!("{}\n", role.as_str()))); + for c in children { + wire_into(c, tokens, images); + } + tokens.push(tokenizer::IM_END); + tokens.extend(tokenizer::encode("\n")); + } + } +} + +impl ContextState { + /// Assemble the prompt in wire form: token stream with a single + /// `<|image_pad|>` per image (vLLM expands back to N), plus the list + /// of images to send as multi_modal_data.
+ pub fn wire_prompt(&self) -> (Vec<u32>, Vec<WireImage>) { + let mut tokens = Vec::new(); + let mut images = Vec::new(); + for section in self.sections() { + for node in section { + wire_into(node, &mut tokens, &mut images); + } + } + (tokens, images) + } +} + impl ContextState { fn section_mut(&mut self, section: Section) -> &mut Vec<AstNode> { match section { @@ -1531,6 +1583,34 @@ mod tests { assert!(rendered.ends_with("<|vision_end|>")); } + #[test] + fn test_wire_prompt_collapses_image_pads() { + let mut ctx = ContextState::new(); + ctx.push_no_log(Section::Conversation, AstNode::branch(Role::User, vec![ + AstNode::content("look:"), + AstNode::image(vec![0xDE, 0xAD], "image/png", 512, 512), + ])); + + // AST side: N image_pads + bookends, full budget accounting. + let full = ctx.token_ids(); + let n_image_pads_full = full.iter() + .filter(|&&t| t == tokenizer::IMAGE_PAD).count(); + assert_eq!(n_image_pads_full, qwen3_image_token_count(512, 512) as usize); + + // Wire side: single image_pad, bytes moved to images list. + let (wire, images) = ctx.wire_prompt(); + let n_image_pads_wire = wire.iter() + .filter(|&&t| t == tokenizer::IMAGE_PAD).count(); + assert_eq!(n_image_pads_wire, 1); + assert_eq!(images.len(), 1); + assert_eq!(images[0].bytes, vec![0xDE, 0xAD]); + assert_eq!(images[0].mime, "image/png"); + + // vision_start/vision_end bookends are preserved in wire form.
+ assert_eq!(wire.iter().filter(|&&t| t == tokenizer::VISION_START).count(), 1); + assert_eq!(wire.iter().filter(|&&t| t == tokenizer::VISION_END).count(), 1); + } + #[test] fn test_image_serde_roundtrip() { let node = AstNode::image(vec![0xDE, 0xAD, 0xBE, 0xEF], "image/png", 64, 64); diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 5368db6..cb50568 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -285,16 +285,23 @@ impl Agent { } pub async fn assemble_prompt_tokens(&self) -> Vec<u32> { + self.assemble_prompt().await.0 + } + + /// Assemble a ready-to-send prompt: token stream in wire form (each + /// image collapsed to a single `<|image_pad|>`) paired with the + /// images to attach as multi_modal_data. + pub async fn assemble_prompt(&self) -> (Vec<u32>, Vec<context::WireImage>) { let ctx = self.context.lock().await; let st = self.state.lock().await; - let mut tokens = ctx.token_ids(); + let (mut tokens, images) = ctx.wire_prompt(); tokens.push(tokenizer::IM_START); if st.think_native { tokens.extend(tokenizer::encode("assistant\n\n")); } else { tokens.extend(tokenizer::encode("assistant\n")); } - tokens + (tokens, images) } /// Rebuild the tools section of the system prompt from the current tools list. @@ -354,10 +361,11 @@ impl Agent { let _thinking = start_activity(&agent, "thinking...").await; let (rx, _stream_guard) = { - let prompt_tokens = agent.assemble_prompt_tokens().await; + let (prompt_tokens, images) = agent.assemble_prompt().await; let st = agent.state.lock().await; - agent.client.stream_completion( + agent.client.stream_completion_mm( &prompt_tokens, + &images, api::SamplingParams { temperature: st.temperature, top_p: st.top_p,