agent: unify prompt assembly across agent and learn paths
wire_prompt() gains a conv_range and a skip closure, and returns the assistant-message token ranges needed by the scoring path. The agent path passes 0..len and |_| false, and ignores the ranges. Memory-ablation scoring and candidate generation pass a prefix range plus a predicate (e.g. is_memory_node, or |n| memory_key(n) == Some(key)). This deletes subconscious/learn.rs's build_token_ids, its private Filter enum, and its is_memory/memory_key duplicates — the walk over context sections now has one home, so adding a section or changing section order in the agent path can no longer silently drift away from what scoring sees. call_score forwards multi_modal_data when the wire-form prompt contains images. generate_alternate switches to stream_completion_mm and passes the same images. Scoring on image-bearing contexts now sends wire form (a single image_pad per image, plus the image data) instead of expanded image_pads with no image data; text-only prompts are bit-identical to before. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
0d1044c2e8
commit
eea7de4753
3 changed files with 98 additions and 108 deletions
|
|
@ -920,19 +920,67 @@ fn wire_into(node: &AstNode, tokens: &mut Vec<u32>, images: &mut Vec<WireImage>)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn memory_key(node: &AstNode) -> Option<&str> {
|
||||||
|
match node {
|
||||||
|
AstNode::Leaf(leaf) => match leaf.body() {
|
||||||
|
NodeBody::Memory { key, .. } => Some(key),
|
||||||
|
_ => None,
|
||||||
|
},
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_memory_node(node: &AstNode) -> bool {
|
||||||
|
matches!(node, AstNode::Leaf(leaf) if matches!(leaf.body(), NodeBody::Memory { .. }))
|
||||||
|
}
|
||||||
|
|
||||||
impl ContextState {
|
impl ContextState {
|
||||||
/// Assemble the prompt in wire form: token stream with a single
|
/// Assemble the prompt in wire form: token stream with a single
|
||||||
/// `<|image_pad|>` per image (vLLM expands back to N), plus the list
|
/// `<|image_pad|>` per image (vLLM expands back to N), plus the list
|
||||||
/// of images to send as multi_modal_data.
|
/// of images to send as multi_modal_data, plus the (start, end) token
|
||||||
pub fn wire_prompt(&self) -> (Vec<u32>, Vec<WireImage>) {
|
/// positions of each assistant message branch emitted (used by the
|
||||||
|
/// scoring path as `score_ranges`).
|
||||||
|
///
|
||||||
|
/// `conv_range` selects a prefix (or any sub-range) of conversation
|
||||||
|
/// entries to include — the agent path passes `0..conversation().len()`;
|
||||||
|
/// scoring / candidate generation pass a prefix up to the entry of
|
||||||
|
/// interest.
|
||||||
|
///
|
||||||
|
/// `skip` is a predicate applied to identity and conversation entries;
|
||||||
|
/// returning true drops the node from the prompt. The agent path passes
|
||||||
|
/// `|_| false`; memory-ablation scoring passes e.g. `is_memory_node` or
|
||||||
|
/// `|n| memory_key(n) == Some(key)`.
|
||||||
|
pub fn wire_prompt<F>(
|
||||||
|
&self,
|
||||||
|
conv_range: std::ops::Range<usize>,
|
||||||
|
mut skip: F,
|
||||||
|
) -> (Vec<u32>, Vec<WireImage>, Vec<(usize, usize)>)
|
||||||
|
where F: FnMut(&AstNode) -> bool,
|
||||||
|
{
|
||||||
let mut tokens = Vec::new();
|
let mut tokens = Vec::new();
|
||||||
let mut images = Vec::new();
|
let mut images = Vec::new();
|
||||||
for section in self.sections() {
|
let mut assistant_ranges = Vec::new();
|
||||||
for node in section {
|
|
||||||
|
for node in self.system() {
|
||||||
wire_into(node, &mut tokens, &mut images);
|
wire_into(node, &mut tokens, &mut images);
|
||||||
}
|
}
|
||||||
|
for node in self.identity() {
|
||||||
|
if skip(node) { continue; }
|
||||||
|
wire_into(node, &mut tokens, &mut images);
|
||||||
}
|
}
|
||||||
(tokens, images)
|
for node in self.journal() {
|
||||||
|
wire_into(node, &mut tokens, &mut images);
|
||||||
|
}
|
||||||
|
for node in &self.conversation()[conv_range] {
|
||||||
|
if skip(node) { continue; }
|
||||||
|
let start = tokens.len();
|
||||||
|
let is_asst = matches!(node, AstNode::Branch { role: Role::Assistant, .. });
|
||||||
|
wire_into(node, &mut tokens, &mut images);
|
||||||
|
if is_asst {
|
||||||
|
assistant_ranges.push((start, tokens.len()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(tokens, images, assistant_ranges)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1598,7 +1646,7 @@ mod tests {
|
||||||
assert_eq!(n_image_pads_full, qwen3_image_token_count(512, 512) as usize);
|
assert_eq!(n_image_pads_full, qwen3_image_token_count(512, 512) as usize);
|
||||||
|
|
||||||
// Wire side: single image_pad, bytes moved to images list.
|
// Wire side: single image_pad, bytes moved to images list.
|
||||||
let (wire, images) = ctx.wire_prompt();
|
let (wire, images, _) = ctx.wire_prompt(0..ctx.conversation().len(), |_| false);
|
||||||
let n_image_pads_wire = wire.iter()
|
let n_image_pads_wire = wire.iter()
|
||||||
.filter(|&&t| t == tokenizer::IMAGE_PAD).count();
|
.filter(|&&t| t == tokenizer::IMAGE_PAD).count();
|
||||||
assert_eq!(n_image_pads_wire, 1);
|
assert_eq!(n_image_pads_wire, 1);
|
||||||
|
|
|
||||||
|
|
@ -294,7 +294,8 @@ impl Agent {
|
||||||
pub async fn assemble_prompt(&self) -> (Vec<u32>, Vec<context::WireImage>) {
|
pub async fn assemble_prompt(&self) -> (Vec<u32>, Vec<context::WireImage>) {
|
||||||
let ctx = self.context.lock().await;
|
let ctx = self.context.lock().await;
|
||||||
let st = self.state.lock().await;
|
let st = self.state.lock().await;
|
||||||
let (mut tokens, images) = ctx.wire_prompt();
|
let (mut tokens, images, _) =
|
||||||
|
ctx.wire_prompt(0..ctx.conversation().len(), |_| false);
|
||||||
tokens.push(tokenizer::IM_START);
|
tokens.push(tokenizer::IM_START);
|
||||||
if st.think_native {
|
if st.think_native {
|
||||||
tokens.extend(tokenizer::encode("assistant\n<think>\n"));
|
tokens.extend(tokenizer::encode("assistant\n<think>\n"));
|
||||||
|
|
|
||||||
|
|
@ -15,95 +15,17 @@
|
||||||
// hasn't internalized. 2 API calls.
|
// hasn't internalized. 2 API calls.
|
||||||
|
|
||||||
use crate::agent::api::ApiClient;
|
use crate::agent::api::ApiClient;
|
||||||
use crate::agent::context::{AstNode, Ast, NodeBody, ContextState, Role};
|
use crate::agent::context::{
|
||||||
|
Ast, AstNode, ContextState, Role, WireImage, is_memory_node, memory_key,
|
||||||
|
};
|
||||||
use crate::agent::tokenizer;
|
use crate::agent::tokenizer;
|
||||||
|
|
||||||
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
|
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
|
||||||
|
|
||||||
// ── Message building ────────────────────────────────────────────
|
|
||||||
|
|
||||||
/// What to filter when building the message array for scoring.
|
|
||||||
#[allow(dead_code)]
|
|
||||||
enum Filter<'a> {
|
|
||||||
None,
|
|
||||||
SkipIndex(usize),
|
|
||||||
SkipKey(&'a str),
|
|
||||||
SkipAllMemories,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_memory(node: &AstNode) -> bool {
|
|
||||||
matches!(node, AstNode::Leaf(leaf) if matches!(leaf.body(), NodeBody::Memory { .. }))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn memory_key(node: &AstNode) -> Option<&str> {
|
|
||||||
match node {
|
|
||||||
AstNode::Leaf(leaf) => match leaf.body() {
|
|
||||||
NodeBody::Memory { key, .. } => Some(key),
|
|
||||||
_ => None,
|
|
||||||
},
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_assistant(node: &AstNode) -> bool {
|
fn is_assistant(node: &AstNode) -> bool {
|
||||||
matches!(node, AstNode::Branch { role: Role::Assistant, .. })
|
matches!(node, AstNode::Branch { role: Role::Assistant, .. })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build a token ID array for a scoring call.
|
|
||||||
///
|
|
||||||
/// Includes all sections up to and including conversation entries in
|
|
||||||
/// `range`, with `filter` applied to conversation entries.
|
|
||||||
///
|
|
||||||
/// Returns (token_ids, assistant_ranges) where assistant_ranges are
|
|
||||||
/// (start, end) token positions for each assistant message.
|
|
||||||
fn build_token_ids(
|
|
||||||
context: &ContextState,
|
|
||||||
range: std::ops::Range<usize>,
|
|
||||||
filter: Filter,
|
|
||||||
) -> (Vec<u32>, Vec<(usize, usize)>) {
|
|
||||||
use crate::agent::context::Ast;
|
|
||||||
let mut ids = Vec::new();
|
|
||||||
let mut assistant_ranges = Vec::new();
|
|
||||||
|
|
||||||
for node in context.system() {
|
|
||||||
ids.extend(node.token_ids());
|
|
||||||
}
|
|
||||||
// Identity nodes can be filtered by key for scoring
|
|
||||||
for node in context.identity() {
|
|
||||||
let skip = match &filter {
|
|
||||||
Filter::SkipKey(key) => memory_key(node) == Some(*key),
|
|
||||||
Filter::SkipAllMemories => is_memory(node),
|
|
||||||
_ => false,
|
|
||||||
};
|
|
||||||
if !skip {
|
|
||||||
ids.extend(node.token_ids());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for node in context.journal() {
|
|
||||||
ids.extend(node.token_ids());
|
|
||||||
}
|
|
||||||
let entries = context.conversation();
|
|
||||||
for i in range {
|
|
||||||
let node = &entries[i];
|
|
||||||
let skip = match &filter {
|
|
||||||
Filter::None => false,
|
|
||||||
Filter::SkipIndex(idx) => i == *idx,
|
|
||||||
Filter::SkipKey(key) => memory_key(node) == Some(*key),
|
|
||||||
Filter::SkipAllMemories => is_memory(node),
|
|
||||||
};
|
|
||||||
if skip { continue; }
|
|
||||||
|
|
||||||
// Track assistant message boundaries
|
|
||||||
let is_asst = is_assistant(node);
|
|
||||||
let start = ids.len();
|
|
||||||
ids.extend(node.token_ids());
|
|
||||||
if is_asst {
|
|
||||||
assistant_ranges.push((start, ids.len()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
(ids, assistant_ranges)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Score API ───────────────────────────────────────────────────
|
// ── Score API ───────────────────────────────────────────────────
|
||||||
|
|
||||||
#[derive(serde::Deserialize)]
|
#[derive(serde::Deserialize)]
|
||||||
|
|
@ -126,6 +48,7 @@ async fn call_score(
|
||||||
http: &crate::agent::api::http::HttpClient,
|
http: &crate::agent::api::http::HttpClient,
|
||||||
client: &ApiClient,
|
client: &ApiClient,
|
||||||
prompt: &[u32],
|
prompt: &[u32],
|
||||||
|
images: &[WireImage],
|
||||||
ranges: &[(usize, usize)],
|
ranges: &[(usize, usize)],
|
||||||
priority: Option<i32>,
|
priority: Option<i32>,
|
||||||
) -> anyhow::Result<Vec<ScoreResult>> {
|
) -> anyhow::Result<Vec<ScoreResult>> {
|
||||||
|
|
@ -141,6 +64,14 @@ async fn call_score(
|
||||||
"score_ranges": ranges,
|
"score_ranges": ranges,
|
||||||
"logprobs": 1,
|
"logprobs": 1,
|
||||||
});
|
});
|
||||||
|
if !images.is_empty() {
|
||||||
|
use base64::Engine;
|
||||||
|
let b64 = base64::engine::general_purpose::STANDARD;
|
||||||
|
let uris: Vec<String> = images.iter()
|
||||||
|
.map(|img| format!("data:{};base64,{}", img.mime, b64.encode(&img.bytes)))
|
||||||
|
.collect();
|
||||||
|
body["multi_modal_data"] = serde_json::json!({ "image": uris });
|
||||||
|
}
|
||||||
if let Some(p) = priority {
|
if let Some(p) = priority {
|
||||||
body["priority"] = serde_json::json!(p);
|
body["priority"] = serde_json::json!(p);
|
||||||
}
|
}
|
||||||
|
|
@ -178,18 +109,24 @@ fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Score two message sets and return total divergence.
|
/// Score two message sets and return total divergence.
|
||||||
async fn score_divergence(
|
async fn score_divergence<F>(
|
||||||
http: &crate::agent::api::http::HttpClient,
|
http: &crate::agent::api::http::HttpClient,
|
||||||
client: &ApiClient,
|
client: &ApiClient,
|
||||||
context: &ContextState,
|
context: &ContextState,
|
||||||
range: std::ops::Range<usize>,
|
range: std::ops::Range<usize>,
|
||||||
filter: Filter<'_>,
|
skip: F,
|
||||||
priority: Option<i32>,
|
priority: Option<i32>,
|
||||||
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)> {
|
) -> anyhow::Result<(Vec<f64>, Vec<ScoreResult>)>
|
||||||
let (baseline_tokens, baseline_ranges) = build_token_ids(context, range.clone(), Filter::None);
|
where F: FnMut(&AstNode) -> bool,
|
||||||
let (without_tokens, without_ranges) = build_token_ids(context, range, filter);
|
{
|
||||||
let baseline = call_score(http, client, &baseline_tokens, &baseline_ranges, priority).await?;
|
let (baseline_tokens, baseline_images, baseline_ranges) =
|
||||||
let without = call_score(http, client, &without_tokens, &without_ranges, priority).await?;
|
context.wire_prompt(range.clone(), |_| false);
|
||||||
|
let (without_tokens, without_images, without_ranges) =
|
||||||
|
context.wire_prompt(range, skip);
|
||||||
|
let baseline = call_score(http, client, &baseline_tokens, &baseline_images,
|
||||||
|
&baseline_ranges, priority).await?;
|
||||||
|
let without = call_score(http, client, &without_tokens, &without_images,
|
||||||
|
&without_ranges, priority).await?;
|
||||||
let divs = divergence(&baseline, &without);
|
let divs = divergence(&baseline, &without);
|
||||||
Ok((divs, baseline))
|
Ok((divs, baseline))
|
||||||
}
|
}
|
||||||
|
|
@ -228,21 +165,22 @@ pub async fn score_memories(
|
||||||
let http = http_client();
|
let http = http_client();
|
||||||
|
|
||||||
let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
|
let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
|
||||||
let (baseline_tokens, baseline_ranges) = {
|
let (baseline_tokens, baseline_images, baseline_ranges) = {
|
||||||
let ctx = agent.context.lock().await;
|
let ctx = agent.context.lock().await;
|
||||||
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::None)
|
ctx.wire_prompt(0..ctx.conversation().len(), |_| false)
|
||||||
};
|
};
|
||||||
let baseline = call_score(&http, client, &baseline_tokens, &baseline_ranges, Some(5)).await?;
|
let baseline = call_score(&http, client, &baseline_tokens, &baseline_images,
|
||||||
|
&baseline_ranges, Some(5)).await?;
|
||||||
dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
|
dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
|
||||||
|
|
||||||
for (mem_idx, key) in memory_keys.iter().enumerate() {
|
for (mem_idx, key) in memory_keys.iter().enumerate() {
|
||||||
activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
|
activity.update(format!("scoring: {}/{}", mem_idx + 1, total)).await;
|
||||||
dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);
|
dbglog!("[scoring-full] {}/{}: {}", mem_idx + 1, total, key);
|
||||||
let (tokens, ranges) = {
|
let (tokens, images, ranges) = {
|
||||||
let ctx = agent.context.lock().await;
|
let ctx = agent.context.lock().await;
|
||||||
build_token_ids(&ctx, 0..ctx.conversation().len(), Filter::SkipKey(key))
|
ctx.wire_prompt(0..ctx.conversation().len(), |n| memory_key(n) == Some(key.as_str()))
|
||||||
};
|
};
|
||||||
let row = match call_score(&http, client, &tokens, &ranges, Some(5)).await {
|
let row = match call_score(&http, client, &tokens, &images, &ranges, Some(5)).await {
|
||||||
Ok(without) => {
|
Ok(without) => {
|
||||||
let divs = divergence(&baseline, &without);
|
let divs = divergence(&baseline, &without);
|
||||||
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
|
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
|
||||||
|
|
@ -326,7 +264,8 @@ pub async fn score_memory(
|
||||||
}
|
}
|
||||||
|
|
||||||
let http = http_client();
|
let http = http_client();
|
||||||
let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipKey(key), Some(5)).await?;
|
let (divs, _) = score_divergence(&http, client, context, range,
|
||||||
|
|n| memory_key(n) == Some(key), Some(5)).await?;
|
||||||
|
|
||||||
Ok(divs.iter().sum())
|
Ok(divs.iter().sum())
|
||||||
}
|
}
|
||||||
|
|
@ -418,7 +357,8 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
activity.update(format!("scoring: {}/{} {}", scored + 1, total, key)).await;
|
activity.update(format!("scoring: {}/{} {}", scored + 1, total, key)).await;
|
||||||
match score_divergence(&http, client, context, range, Filter::SkipKey(key), Some(5)).await {
|
match score_divergence(&http, client, context, range,
|
||||||
|
|n| memory_key(n) == Some(key), Some(5)).await {
|
||||||
Ok((divs, _)) => {
|
Ok((divs, _)) => {
|
||||||
let n_responses = divs.len();
|
let n_responses = divs.len();
|
||||||
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
|
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
|
||||||
|
|
@ -464,7 +404,7 @@ pub async fn score_finetune(
|
||||||
}
|
}
|
||||||
|
|
||||||
let http = http_client();
|
let http = http_client();
|
||||||
let (divs, _) = score_divergence(&http, client, context, range, Filter::SkipAllMemories, Some(5)).await?;
|
let (divs, _) = score_divergence(&http, client, context, range, is_memory_node, Some(5)).await?;
|
||||||
|
|
||||||
let mut results: Vec<(usize, f64)> = response_positions.iter()
|
let mut results: Vec<(usize, f64)> = response_positions.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
|
|
@ -593,7 +533,7 @@ pub async fn score_finetune_candidates(
|
||||||
let prior_context = render_prior_context(entries, entry_idx, 2);
|
let prior_context = render_prior_context(entries, entry_idx, 2);
|
||||||
|
|
||||||
// Build token IDs: context = everything before response, continuation = response.
|
// Build token IDs: context = everything before response, continuation = response.
|
||||||
let (context_ids, _) = build_token_ids(context, 0..entry_idx, Filter::None);
|
let (context_ids, _, _) = context.wire_prompt(0..entry_idx, |_| false);
|
||||||
let continuation_ids: Vec<u32> = node.token_ids().into_iter().collect();
|
let continuation_ids: Vec<u32> = node.token_ids().into_iter().collect();
|
||||||
|
|
||||||
candidates.push(FinetuneCandidate {
|
candidates.push(FinetuneCandidate {
|
||||||
|
|
@ -636,7 +576,8 @@ async fn generate_alternate(
|
||||||
use crate::agent::api::{SamplingParams, StreamToken};
|
use crate::agent::api::{SamplingParams, StreamToken};
|
||||||
|
|
||||||
// Build context tokens without memories, up to the response
|
// Build context tokens without memories, up to the response
|
||||||
let (mut prompt, _) = build_token_ids(context, 0..entry_idx, Filter::SkipAllMemories);
|
let (mut prompt, images, _) =
|
||||||
|
context.wire_prompt(0..entry_idx, is_memory_node);
|
||||||
|
|
||||||
// Add assistant turn start
|
// Add assistant turn start
|
||||||
prompt.push(tokenizer::IM_START);
|
prompt.push(tokenizer::IM_START);
|
||||||
|
|
@ -648,7 +589,7 @@ async fn generate_alternate(
|
||||||
top_p: 0.95,
|
top_p: 0.95,
|
||||||
top_k: 20,
|
top_k: 20,
|
||||||
};
|
};
|
||||||
let (mut rx, _guard) = client.stream_completion(&prompt, sampling, Some(-5));
|
let (mut rx, _guard) = client.stream_completion_mm(&prompt, &images, sampling, Some(-5));
|
||||||
|
|
||||||
let mut tokens = Vec::new();
|
let mut tokens = Vec::new();
|
||||||
while let Some(tok) = rx.recv().await {
|
while let Some(tok) = rx.recv().await {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue