agent: send images as multi_modal_data on completion requests
Split the prompt assembly into two forms: the AST keeps the fully-expanded representation (N image_pads per image, for accurate context budget accounting), while the request wire form collapses each image to a single <|image_pad|> bookended by vision_start/vision_end and ships the raw bytes out-of-band as a base64 data URI in a new `multi_modal_data.image` field on /v1/completions. vLLM's Qwen3VL processor uses PromptReplacement with target = a single <|image_pad|> and replacement = N image_pads, so the wire form matches what the processor expects, and the processor re-expands it to N server-side. The server side still needs /v1/completions to accept multi_modal_data for images to land end-to-end — that's the next piece.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
91106deaa1
commit
204ba5570a
3 changed files with 115 additions and 5 deletions
|
|
@ -78,18 +78,31 @@ impl ApiClient {
|
||||||
prompt_tokens: &[u32],
|
prompt_tokens: &[u32],
|
||||||
sampling: SamplingParams,
|
sampling: SamplingParams,
|
||||||
priority: Option<i32>,
|
priority: Option<i32>,
|
||||||
|
) -> (mpsc::UnboundedReceiver<StreamToken>, AbortOnDrop) {
|
||||||
|
self.stream_completion_mm(prompt_tokens, &[], sampling, priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn stream_completion_mm(
|
||||||
|
&self,
|
||||||
|
prompt_tokens: &[u32],
|
||||||
|
images: &[super::context::WireImage],
|
||||||
|
sampling: SamplingParams,
|
||||||
|
priority: Option<i32>,
|
||||||
) -> (mpsc::UnboundedReceiver<StreamToken>, AbortOnDrop) {
|
) -> (mpsc::UnboundedReceiver<StreamToken>, AbortOnDrop) {
|
||||||
let (tx, rx) = mpsc::unbounded_channel();
|
let (tx, rx) = mpsc::unbounded_channel();
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
let api_key = self.api_key.clone();
|
let api_key = self.api_key.clone();
|
||||||
let model = self.model.clone();
|
let model = self.model.clone();
|
||||||
let prompt_tokens = prompt_tokens.to_vec();
|
let prompt_tokens = prompt_tokens.to_vec();
|
||||||
|
let images: Vec<(Vec<u8>, String)> = images.iter()
|
||||||
|
.map(|i| (i.bytes.clone(), i.mime.clone()))
|
||||||
|
.collect();
|
||||||
let base_url = self.base_url.clone();
|
let base_url = self.base_url.clone();
|
||||||
|
|
||||||
let handle = tokio::spawn(async move {
|
let handle = tokio::spawn(async move {
|
||||||
let result = stream_completions(
|
let result = stream_completions(
|
||||||
&client, &base_url, &api_key, &model,
|
&client, &base_url, &api_key, &model,
|
||||||
&prompt_tokens, &tx, sampling, priority,
|
&prompt_tokens, &images, &tx, sampling, priority,
|
||||||
).await;
|
).await;
|
||||||
if let Err(e) = result {
|
if let Err(e) = result {
|
||||||
let _ = tx.send(StreamToken::Error(e.to_string()));
|
let _ = tx.send(StreamToken::Error(e.to_string()));
|
||||||
|
|
@ -110,6 +123,7 @@ async fn stream_completions(
|
||||||
api_key: &str,
|
api_key: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
prompt_tokens: &[u32],
|
prompt_tokens: &[u32],
|
||||||
|
images: &[(Vec<u8>, String)],
|
||||||
tx: &mpsc::UnboundedSender<StreamToken>,
|
tx: &mpsc::UnboundedSender<StreamToken>,
|
||||||
sampling: SamplingParams,
|
sampling: SamplingParams,
|
||||||
priority: Option<i32>,
|
priority: Option<i32>,
|
||||||
|
|
@ -126,6 +140,14 @@ async fn stream_completions(
|
||||||
"skip_special_tokens": false,
|
"skip_special_tokens": false,
|
||||||
"stop_token_ids": [super::tokenizer::IM_END],
|
"stop_token_ids": [super::tokenizer::IM_END],
|
||||||
});
|
});
|
||||||
|
if !images.is_empty() {
|
||||||
|
use base64::Engine;
|
||||||
|
let b64 = base64::engine::general_purpose::STANDARD;
|
||||||
|
let uris: Vec<String> = images.iter()
|
||||||
|
.map(|(bytes, mime)| format!("data:{};base64,{}", mime, b64.encode(bytes)))
|
||||||
|
.collect();
|
||||||
|
request["multi_modal_data"] = serde_json::json!({ "image": uris });
|
||||||
|
}
|
||||||
if let Some(p) = priority {
|
if let Some(p) = priority {
|
||||||
request["priority"] = serde_json::json!(p);
|
request["priority"] = serde_json::json!(p);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -884,6 +884,58 @@ impl Ast for ContextState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// An image collected from the AST for a request body.
///
/// The AST stores the pre-expanded token form (N image_pads) for accurate
/// budget accounting; the wire form collapses each Image to a single
/// `<|image_pad|>` between vision bookends and ships the bytes separately
/// as multi_modal_data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WireImage {
    /// Raw image payload, copied from the AST image node.
    pub bytes: Vec<u8>,
    /// MIME type string stored alongside the image bytes (e.g. `image/png`).
    pub mime: String,
}
|
||||||
|
|
||||||
|
fn wire_into(node: &AstNode, tokens: &mut Vec<u32>, images: &mut Vec<WireImage>) {
|
||||||
|
match node {
|
||||||
|
AstNode::Leaf(leaf) => match leaf.body() {
|
||||||
|
NodeBody::Image { bytes, mime, .. } => {
|
||||||
|
tokens.push(tokenizer::VISION_START);
|
||||||
|
tokens.push(tokenizer::IMAGE_PAD);
|
||||||
|
tokens.push(tokenizer::VISION_END);
|
||||||
|
images.push(WireImage {
|
||||||
|
bytes: bytes.clone(),
|
||||||
|
mime: mime.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => tokens.extend_from_slice(leaf.token_ids()),
|
||||||
|
},
|
||||||
|
AstNode::Branch { role, children, .. } => {
|
||||||
|
tokens.push(tokenizer::IM_START);
|
||||||
|
tokens.extend(tokenizer::encode(&format!("{}\n", role.as_str())));
|
||||||
|
for c in children {
|
||||||
|
wire_into(c, tokens, images);
|
||||||
|
}
|
||||||
|
tokens.push(tokenizer::IM_END);
|
||||||
|
tokens.extend(tokenizer::encode("\n"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ContextState {
|
||||||
|
/// Assemble the prompt in wire form: token stream with a single
|
||||||
|
/// `<|image_pad|>` per image (vLLM expands back to N), plus the list
|
||||||
|
/// of images to send as multi_modal_data.
|
||||||
|
pub fn wire_prompt(&self) -> (Vec<u32>, Vec<WireImage>) {
|
||||||
|
let mut tokens = Vec::new();
|
||||||
|
let mut images = Vec::new();
|
||||||
|
for section in self.sections() {
|
||||||
|
for node in section {
|
||||||
|
wire_into(node, &mut tokens, &mut images);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(tokens, images)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ContextState {
|
impl ContextState {
|
||||||
fn section_mut(&mut self, section: Section) -> &mut Vec<AstNode> {
|
fn section_mut(&mut self, section: Section) -> &mut Vec<AstNode> {
|
||||||
match section {
|
match section {
|
||||||
|
|
@ -1531,6 +1583,34 @@ mod tests {
|
||||||
assert!(rendered.ends_with("<|vision_end|>"));
|
assert!(rendered.ends_with("<|vision_end|>"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
fn test_wire_prompt_collapses_image_pads() {
    /// Number of times `needle` occurs in `ids`.
    fn occurrences(ids: &[u32], needle: u32) -> usize {
        ids.iter().filter(|&&id| id == needle).count()
    }

    let mut ctx = ContextState::new();
    let user_turn = AstNode::branch(Role::User, vec![
        AstNode::content("look:"),
        AstNode::image(vec![0xDE, 0xAD], "image/png", 512, 512),
    ]);
    ctx.push_no_log(Section::Conversation, user_turn);

    // AST form keeps the fully expanded run of image pads.
    let full = ctx.token_ids();
    assert_eq!(
        occurrences(&full, tokenizer::IMAGE_PAD),
        qwen3_image_token_count(512, 512) as usize
    );

    // Wire form carries one pad and hands the payload back separately.
    let (wire, images) = ctx.wire_prompt();
    assert_eq!(occurrences(&wire, tokenizer::IMAGE_PAD), 1);
    assert_eq!(images.len(), 1);
    assert_eq!(images[0].bytes, vec![0xDE, 0xAD]);
    assert_eq!(images[0].mime, "image/png");

    // The vision bookends survive the collapse.
    assert_eq!(occurrences(&wire, tokenizer::VISION_START), 1);
    assert_eq!(occurrences(&wire, tokenizer::VISION_END), 1);
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_image_serde_roundtrip() {
|
fn test_image_serde_roundtrip() {
|
||||||
let node = AstNode::image(vec![0xDE, 0xAD, 0xBE, 0xEF], "image/png", 64, 64);
|
let node = AstNode::image(vec![0xDE, 0xAD, 0xBE, 0xEF], "image/png", 64, 64);
|
||||||
|
|
|
||||||
|
|
@ -285,16 +285,23 @@ impl Agent {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn assemble_prompt_tokens(&self) -> Vec<u32> {
|
pub async fn assemble_prompt_tokens(&self) -> Vec<u32> {
|
||||||
|
self.assemble_prompt().await.0
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Assemble a ready-to-send prompt: token stream in wire form (each
|
||||||
|
/// image collapsed to a single `<|image_pad|>`) paired with the
|
||||||
|
/// images to attach as multi_modal_data.
|
||||||
|
pub async fn assemble_prompt(&self) -> (Vec<u32>, Vec<context::WireImage>) {
|
||||||
let ctx = self.context.lock().await;
|
let ctx = self.context.lock().await;
|
||||||
let st = self.state.lock().await;
|
let st = self.state.lock().await;
|
||||||
let mut tokens = ctx.token_ids();
|
let (mut tokens, images) = ctx.wire_prompt();
|
||||||
tokens.push(tokenizer::IM_START);
|
tokens.push(tokenizer::IM_START);
|
||||||
if st.think_native {
|
if st.think_native {
|
||||||
tokens.extend(tokenizer::encode("assistant\n<think>\n"));
|
tokens.extend(tokenizer::encode("assistant\n<think>\n"));
|
||||||
} else {
|
} else {
|
||||||
tokens.extend(tokenizer::encode("assistant\n"));
|
tokens.extend(tokenizer::encode("assistant\n"));
|
||||||
}
|
}
|
||||||
tokens
|
(tokens, images)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Rebuild the tools section of the system prompt from the current tools list.
|
/// Rebuild the tools section of the system prompt from the current tools list.
|
||||||
|
|
@ -354,10 +361,11 @@ impl Agent {
|
||||||
let _thinking = start_activity(&agent, "thinking...").await;
|
let _thinking = start_activity(&agent, "thinking...").await;
|
||||||
|
|
||||||
let (rx, _stream_guard) = {
|
let (rx, _stream_guard) = {
|
||||||
let prompt_tokens = agent.assemble_prompt_tokens().await;
|
let (prompt_tokens, images) = agent.assemble_prompt().await;
|
||||||
let st = agent.state.lock().await;
|
let st = agent.state.lock().await;
|
||||||
agent.client.stream_completion(
|
agent.client.stream_completion_mm(
|
||||||
&prompt_tokens,
|
&prompt_tokens,
|
||||||
|
&images,
|
||||||
api::SamplingParams {
|
api::SamplingParams {
|
||||||
temperature: st.temperature,
|
temperature: st.temperature,
|
||||||
top_p: st.top_p,
|
top_p: st.top_p,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue