agent: unify prompt assembly across agent and learn paths

wire_prompt() gains a conv_range and a skip closure, and returns the
assistant-message token ranges needed by the scoring path. The agent
path passes 0..len + |_| false and ignores the ranges. Memory-ablation
scoring and candidate generation pass a prefix range + a predicate
(e.g. is_memory_node, or |n| memory_key(n) == Some(key)).

This deletes subconscious/learn.rs's build_token_ids, its private
Filter enum, and the is_memory/memory_key duplicates — the walk over
context sections now has one home. Adding a section or changing
section order in the agent path won't silently drift away from what
scoring sees.

call_score forwards multi_modal_data when the wire-form prompt
contains images. generate_alternate switches to stream_completion_mm
and passes the same images. Scoring on image-bearing contexts now
sends wire form (1 image_pad + image data) instead of expanded
image_pads with no image data; text-only contexts are bit-identical.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-17 15:16:07 -04:00
parent 0d1044c2e8
commit eea7de4753
3 changed files with 98 additions and 108 deletions

View file

@ -920,19 +920,67 @@ fn wire_into(node: &AstNode, tokens: &mut Vec<u32>, images: &mut Vec<WireImage>)
}
}
pub fn memory_key(node: &AstNode) -> Option<&str> {
match node {
AstNode::Leaf(leaf) => match leaf.body() {
NodeBody::Memory { key, .. } => Some(key),
_ => None,
},
_ => None,
}
}
pub fn is_memory_node(node: &AstNode) -> bool {
matches!(node, AstNode::Leaf(leaf) if matches!(leaf.body(), NodeBody::Memory { .. }))
}
impl ContextState {
/// Assemble the prompt in wire form: token stream with a single
/// `<|image_pad|>` per image (vLLM expands back to N), plus the list
/// of images to send as multi_modal_data.
pub fn wire_prompt(&self) -> (Vec<u32>, Vec<WireImage>) {
/// of images to send as multi_modal_data, plus the (start, end) token
/// positions of each assistant message branch emitted (used by the
/// scoring path as `score_ranges`).
///
/// `conv_range` selects a prefix (or any sub-range) of conversation
/// entries to include — the agent path passes `0..conversation().len()`;
/// scoring / candidate generation pass a prefix up to the entry of
/// interest.
///
/// `skip` is a predicate applied to identity and conversation entries;
/// returning true drops the node from the prompt. The agent path passes
/// `|_| false`; memory-ablation scoring passes e.g. `is_memory_node` or
/// `|n| memory_key(n) == Some(key)`.
pub fn wire_prompt<F>(
&self,
conv_range: std::ops::Range<usize>,
mut skip: F,
) -> (Vec<u32>, Vec<WireImage>, Vec<(usize, usize)>)
where F: FnMut(&AstNode) -> bool,
{
let mut tokens = Vec::new();
let mut images = Vec::new();
for section in self.sections() {
for node in section {
wire_into(node, &mut tokens, &mut images);
let mut assistant_ranges = Vec::new();
for node in self.system() {
wire_into(node, &mut tokens, &mut images);
}
for node in self.identity() {
if skip(node) { continue; }
wire_into(node, &mut tokens, &mut images);
}
for node in self.journal() {
wire_into(node, &mut tokens, &mut images);
}
for node in &self.conversation()[conv_range] {
if skip(node) { continue; }
let start = tokens.len();
let is_asst = matches!(node, AstNode::Branch { role: Role::Assistant, .. });
wire_into(node, &mut tokens, &mut images);
if is_asst {
assistant_ranges.push((start, tokens.len()));
}
}
(tokens, images)
(tokens, images, assistant_ranges)
}
}
@ -1598,7 +1646,7 @@ mod tests {
assert_eq!(n_image_pads_full, qwen3_image_token_count(512, 512) as usize);
// Wire side: single image_pad, bytes moved to images list.
let (wire, images) = ctx.wire_prompt();
let (wire, images, _) = ctx.wire_prompt(0..ctx.conversation().len(), |_| false);
let n_image_pads_wire = wire.iter()
.filter(|&&t| t == tokenizer::IMAGE_PAD).count();
assert_eq!(n_image_pads_wire, 1);

View file

@ -294,7 +294,8 @@ impl Agent {
pub async fn assemble_prompt(&self) -> (Vec<u32>, Vec<context::WireImage>) {
let ctx = self.context.lock().await;
let st = self.state.lock().await;
let (mut tokens, images) = ctx.wire_prompt();
let (mut tokens, images, _) =
ctx.wire_prompt(0..ctx.conversation().len(), |_| false);
tokens.push(tokenizer::IM_START);
if st.think_native {
tokens.extend(tokenizer::encode("assistant\n<think>\n"));