diff --git a/src/agent/context.rs b/src/agent/context.rs
index 0082f06..38127d5 100644
--- a/src/agent/context.rs
+++ b/src/agent/context.rs
@@ -920,19 +920,67 @@ fn wire_into(node: &AstNode, tokens: &mut Vec<u32>, images: &mut Vec<WireImage>)
     }
 }
 
+pub fn memory_key(node: &AstNode) -> Option<&str> {
+    match node {
+        AstNode::Leaf(leaf) => match leaf.body() {
+            NodeBody::Memory { key, .. } => Some(key),
+            _ => None,
+        },
+        _ => None,
+    }
+}
+
+pub fn is_memory_node(node: &AstNode) -> bool {
+    matches!(node, AstNode::Leaf(leaf) if matches!(leaf.body(), NodeBody::Memory { .. }))
+}
+
 impl ContextState {
     /// Assemble the prompt in wire form: token stream with a single
     /// `<|image_pad|>` per image (vLLM expands back to N), plus the list
-    /// of images to send as multi_modal_data.
-    pub fn wire_prompt(&self) -> (Vec<u32>, Vec<WireImage>) {
+    /// of images to send as multi_modal_data, plus the (start, end) token
+    /// positions of each assistant message branch emitted (used by the
+    /// scoring path as `score_ranges`).
+    ///
+    /// `conv_range` selects a prefix (or any sub-range) of conversation
+    /// entries to include — the agent path passes `0..conversation().len()`;
+    /// scoring / candidate generation pass a prefix up to the entry of
+    /// interest.
+    ///
+    /// `skip` is a predicate applied to identity and conversation entries;
+    /// returning true drops the node from the prompt. The agent path passes
+    /// `|_| false`; memory-ablation scoring passes e.g. `is_memory_node` or
+    /// `|n| memory_key(n) == Some(key)`.
+    pub fn wire_prompt<F>(
+        &self,
+        conv_range: std::ops::Range<usize>,
+        mut skip: F,
+    ) -> (Vec<u32>, Vec<WireImage>, Vec<(usize, usize)>)
+    where F: FnMut(&AstNode) -> bool,
+    {
         let mut tokens = Vec::new();
         let mut images = Vec::new();
-        for section in self.sections() {
-            for node in section {
-                wire_into(node, &mut tokens, &mut images);
+        let mut assistant_ranges = Vec::new();
+
+        for node in self.system() {
+            wire_into(node, &mut tokens, &mut images);
+        }
+        for node in self.identity() {
+            if skip(node) { continue; }
+            wire_into(node, &mut tokens, &mut images);
+        }
+        for node in self.journal() {
+            wire_into(node, &mut tokens, &mut images);
+        }
+        for node in &self.conversation()[conv_range] {
+            if skip(node) { continue; }
+            let start = tokens.len();
+            let is_asst = matches!(node, AstNode::Branch { role: Role::Assistant, .. });
+            wire_into(node, &mut tokens, &mut images);
+            if is_asst {
+                assistant_ranges.push((start, tokens.len()));
             }
         }
-        (tokens, images)
+        (tokens, images, assistant_ranges)
     }
 }
 
@@ -1598,7 +1646,7 @@ mod tests {
         assert_eq!(n_image_pads_full, qwen3_image_token_count(512, 512) as usize);
 
         // Wire side: single image_pad, bytes moved to images list.
-        let (wire, images) = ctx.wire_prompt();
+        let (wire, images, _) = ctx.wire_prompt(0..ctx.conversation().len(), |_| false);
         let n_image_pads_wire = wire.iter()
             .filter(|&&t| t == tokenizer::IMAGE_PAD).count();
         assert_eq!(n_image_pads_wire, 1);
diff --git a/src/agent/mod.rs b/src/agent/mod.rs
index bc62955..436dda3 100644
--- a/src/agent/mod.rs
+++ b/src/agent/mod.rs
@@ -294,7 +294,8 @@ impl Agent {
     pub async fn assemble_prompt(&self) -> (Vec<u32>, Vec<WireImage>) {
         let ctx = self.context.lock().await;
         let st = self.state.lock().await;
-        let (mut tokens, images) = ctx.wire_prompt();
+        let (mut tokens, images, _) =
+            ctx.wire_prompt(0..ctx.conversation().len(), |_| false);
         tokens.push(tokenizer::IM_START);
         if st.think_native {
             tokens.extend(tokenizer::encode("assistant\n\n"));
diff --git a/src/subconscious/learn.rs b/src/subconscious/learn.rs
index 7137211..26c854b 100644
--- a/src/subconscious/learn.rs
+++ b/src/subconscious/learn.rs
@@ -15,95 +15,17 @@
 // hasn't internalized. 2 API calls.
 
 use crate::agent::api::ApiClient;
-use crate::agent::context::{AstNode, Ast, NodeBody, ContextState, Role};
+use crate::agent::context::{
+    Ast, AstNode, ContextState, Role, WireImage, is_memory_node, memory_key,
+};
 use crate::agent::tokenizer;
 
 const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
 
-// ── Message building ────────────────────────────────────────────
-
-/// What to filter when building the message array for scoring.
-#[allow(dead_code)]
-enum Filter<'a> {
-    None,
-    SkipIndex(usize),
-    SkipKey(&'a str),
-    SkipAllMemories,
-}
-
-fn is_memory(node: &AstNode) -> bool {
-    matches!(node, AstNode::Leaf(leaf) if matches!(leaf.body(), NodeBody::Memory { .. }))
-}
-
-fn memory_key(node: &AstNode) -> Option<&str> {
-    match node {
-        AstNode::Leaf(leaf) => match leaf.body() {
-            NodeBody::Memory { key, .. } => Some(key),
-            _ => None,
-        },
-        _ => None,
-    }
-}
-
 fn is_assistant(node: &AstNode) -> bool {
     matches!(node, AstNode::Branch { role: Role::Assistant, .. })
 }
 
-/// Build a token ID array for a scoring call.
-///
-/// Includes all sections up to and including conversation entries in
-/// `range`, with `filter` applied to conversation entries.
-///
-/// Returns (token_ids, assistant_ranges) where assistant_ranges are
-/// (start, end) token positions for each assistant message.
-fn build_token_ids(
-    context: &ContextState,
-    range: std::ops::Range<usize>,
-    filter: Filter,
-) -> (Vec<u32>, Vec<(usize, usize)>) {
-    use crate::agent::context::Ast;
-    let mut ids = Vec::new();
-    let mut assistant_ranges = Vec::new();
-
-    for node in context.system() {
-        ids.extend(node.token_ids());
-    }
-    // Identity nodes can be filtered by key for scoring
-    for node in context.identity() {
-        let skip = match &filter {
-            Filter::SkipKey(key) => memory_key(node) == Some(*key),
-            Filter::SkipAllMemories => is_memory(node),
-            _ => false,
-        };
-        if !skip {
-            ids.extend(node.token_ids());
-        }
-    }
-    for node in context.journal() {
-        ids.extend(node.token_ids());
-    }
-    let entries = context.conversation();
-    for i in range {
-        let node = &entries[i];
-        let skip = match &filter {
-            Filter::None => false,
-            Filter::SkipIndex(idx) => i == *idx,
-            Filter::SkipKey(key) => memory_key(node) == Some(*key),
-            Filter::SkipAllMemories => is_memory(node),
-        };
-        if skip { continue; }
-
-        // Track assistant message boundaries
-        let is_asst = is_assistant(node);
-        let start = ids.len();
-        ids.extend(node.token_ids());
-        if is_asst {
-            assistant_ranges.push((start, ids.len()));
-        }
-    }
-    (ids, assistant_ranges)
-}
-
 // ── Score API ───────────────────────────────────────────────────
 
 #[derive(serde::Deserialize)]
@@ -126,6 +48,7 @@ async fn call_score(
     http: &crate::agent::api::http::HttpClient,
     client: &ApiClient,
     prompt: &[u32],
+    images: &[WireImage],
     ranges: &[(usize, usize)],
     priority: Option<i32>,
 ) -> anyhow::Result<Vec<ScoreResult>> {
@@ -141,6 +64,14 @@
         "score_ranges": ranges,
         "logprobs": 1,
     });
+    if !images.is_empty() {
+        use base64::Engine;
+        let b64 = base64::engine::general_purpose::STANDARD;
+        let uris: Vec<String> = images.iter()
+            .map(|img| format!("data:{};base64,{}", img.mime, b64.encode(&img.bytes)))
+            .collect();
+        body["multi_modal_data"] = serde_json::json!({ "image": uris });
+    }
     if let Some(p) = priority {
         body["priority"] = serde_json::json!(p);
     }
@@ -178,18 +109,24 @@ fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
 }
 
 /// Score two message sets and return total divergence.
-async fn score_divergence(
+async fn score_divergence<F>(
     http: &crate::agent::api::http::HttpClient,
     client: &ApiClient,
     context: &ContextState,
     range: std::ops::Range<usize>,
-    filter: Filter<'_>,
+    skip: F,
     priority: Option<i32>,
-) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
-    let (baseline_tokens, baseline_ranges) = build_token_ids(context, range.clone(), Filter::None);
-    let (without_tokens, without_ranges) = build_token_ids(context, range, filter);
-    let baseline = call_score(http, client, &baseline_tokens, &baseline_ranges, priority).await?;
-    let without = call_score(http, client, &without_tokens, &without_ranges, priority).await?;
+) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)>
+where F: FnMut(&AstNode) -> bool,
+{
+    let (baseline_tokens, baseline_images, baseline_ranges) =
+        context.wire_prompt(range.clone(), |_| false);
+    let (without_tokens, without_images, without_ranges) =
+        context.wire_prompt(range, skip);
+    let baseline = call_score(http, client, &baseline_tokens, &baseline_images,
+        &baseline_ranges, priority).await?;
+    let without = call_score(http, client, &without_tokens, &without_images,
+        &without_ranges, priority).await?;
     let divs = divergence(&baseline, &without);
     Ok((divs, baseline))
 }
@@ -228,21 +165,22 @@ pub async fn score_memories(
     let http = http_client();
     let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
 
-    let (baseline_tokens, baseline_ranges) = {
+    let (baseline_tokens, baseline_images, baseline_ranges) = {
         let ctx = agent.context.lock().await;
-        build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None)
+        ctx.wire_prompt(0..ctx.conversation().len(), |_| false)
     };
-    let baseline = call_score(&http, client, &baseline_tokens, &baseline_ranges, Some(5)).await?;
+    let baseline = call_score(&http, client, &baseline_tokens, &baseline_images,
+        &baseline_ranges, Some(5)).await?;
     dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
 
     for (mem_idx, key) in memory_keys.iter().enumerate() {
         activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
         dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);
-        let (tokens, ranges) = {
+        let (tokens, images, ranges) = {
             let ctx = agent.context.lock().await;
-            build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key))
+            ctx.wire_prompt(0..ctx.conversation().len(),
+                |n| memory_key(n) == Some(key.as_str()))
         };
-        let row = match call_score(&http, client, &tokens, &ranges, Some(5)).await {
+        let row = match call_score(&http, client, &tokens, &images, &ranges, Some(5)).await {
             Ok(without) => {
                 let divs = divergence(&baseline, &without);
                 let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
@@ -326,7 +264,8 @@ pub async fn score_memory(
     }
 
     let http = http_client();
-    let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key), Some(5)).await?;
+    let (divs, _) = score_divergence(&http, client, context, range,
+        |n| memory_key(n) == Some(key), Some(5)).await?;
     Ok(divs.iter().sum())
 }
 
@@ -418,7 +357,8 @@ where
         }
         activity.update(format!("scoring: {}/{} {}", scored + 1, total, key)).await;
 
-        match score_divergence(&http, client, context, range, Filter::SkipKey(key), Some(5)).await {
+        match score_divergence(&http, client, context, range,
+            |n| memory_key(n) == Some(key), Some(5)).await {
             Ok((divs, _)) => {
                 let n_responses = divs.len();
                 let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
@@ -464,7 +404,7 @@ pub async fn score_finetune(
     }
 
     let http = http_client();
-    let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipAllMemories, Some(5)).await?;
+    let (divs, _) = score_divergence(&http, client, context, range, is_memory_node, Some(5)).await?;
 
     let mut results: Vec<(usize, f64)> = response_positions.iter()
         .enumerate()
@@ -593,7 +533,7 @@ pub async fn score_finetune_candidates(
         let prior_context = render_prior_context(entries, entry_idx, 2);
 
         // Build token IDs: context = everything before response, continuation = response.
-        let (context_ids, _) = build_token_ids(context, 0..entry_idx, Filter::None);
+        let (context_ids, _, _) = context.wire_prompt(0..entry_idx, |_| false);
         let continuation_ids: Vec<u32> = node.token_ids().into_iter().collect();
 
         candidates.push(FinetuneCandidate {
@@ -636,7 +576,8 @@ async fn generate_alternate(
     use crate::agent::api::{SamplingParams, StreamToken};
 
     // Build context tokens without memories, up to the response
-    let (mut prompt, _) = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories);
+    let (mut prompt, images, _) =
+        context.wire_prompt(0..entry_idx, is_memory_node);
 
     // Add assistant turn start
     prompt.push(tokenizer::IM_START);
@@ -648,7 +589,7 @@
         top_p: 0.95,
         top_k: 20,
     };
-    let (mut rx, _guard) = client.stream_completion(&prompt, sampling, Some(-5));
+    let (mut rx, _guard) = client.stream_completion_mm(&prompt, &images, sampling, Some(-5));
 
     let mut tokens = Vec::new();
     while let Some(tok) = rx.recv().await {