diff --git a/src/agent/context.rs b/src/agent/context.rs index 0a49e05..d61136f 100644 --- a/src/agent/context.rs +++ b/src/agent/context.rs @@ -154,8 +154,19 @@ pub enum AstNode { } /// The context window: four sections as Vec. -/// All mutation goes through ContextState methods to maintain the invariant -/// that token_ids on every leaf matches its rendered text. +/// +/// All mutation MUST go through `ContextState`'s public methods. Two +/// invariants ride on this: +/// 1. Every `Leaf.token_ids` matches its `body.compute_token_ids()`. +/// 2. For every `Branch { token_ids: Some(cached), .. }`, the cached +/// token stream matches what `wire_into` would produce by walking +/// `children` from scratch. Any mutation that touches a Branch's +/// children — directly or via a descendant — must clear the +/// Branch's `token_ids` so it gets recomputed on next wire-out. +/// +/// The `&mut Vec` escape hatches are intentionally NOT +/// exposed; if you find yourself wanting one, add a focused method +/// here that maintains the invariants. pub struct ContextState { system: Vec, identity: Vec, @@ -966,7 +977,33 @@ impl ContextState { pub fn identity(&self) -> &[AstNode] { &self.identity } pub fn journal(&self) -> &[AstNode] { &self.journal } pub fn conversation(&self) -> &[AstNode] { &self.conversation } - pub fn conversation_mut(&mut self) -> &mut Vec { &mut self.conversation } + + /// Set or clear a single `memory_scores` entry on an Assistant + /// Branch. Used by the full-matrix scorer to attribute per-memory + /// divergence onto the response. `score = None` removes the key; + /// `Some(s)` inserts/overwrites. + /// + /// Doesn't affect the Branch's token cache: `memory_scores` is a + /// serialized-but-non-tokenizing annotation. No-op (with a debug + /// log) if the index points to a Leaf or a non-Assistant Branch — + /// callers are typically iterating on stale indices and we'd + /// rather skip than panic. + pub fn set_branch_memory_score( + &mut self, + section: Section, + index: usize, + key: &str, + score: Option, + ) { + let nodes = self.section_mut(section); + let Some(node) = nodes.get_mut(index) else { return }; + let AstNode::Branch { role: Role::Assistant, memory_scores, .. } = node + else { return }; + match score { + Some(s) => { memory_scores.insert(key.to_string(), s); } + None => { memory_scores.remove(key); } + } + } pub fn sections(&self) -> [&Vec; 4] { [&self.system, &self.identity, &self.journal, &self.conversation] @@ -1051,8 +1088,14 @@ fn wire_into(node: &AstNode, tokens: &mut Vec, images: &mut Vec) } _ => tokens.extend_from_slice(leaf.token_ids()), }, - AstNode::Branch { token_ids: Some(cached), .. } => { + AstNode::Branch { token_ids: Some(cached), children, .. } => { + // Cached branches still need their image children paired + // up with the vision-block ranges embedded in the cached + // token stream — the cache captures vision tokens but not + // the matching bytes/mime. + let base = tokens.len() as u32; tokens.extend_from_slice(cached); + pair_cached_images(cached, children, base, images); } AstNode::Branch { role, children, token_ids: None, .. } => { tokens.push(tokenizer::IM_START); @@ -1066,6 +1109,101 @@ fn wire_into(node: &AstNode, tokens: &mut Vec, images: &mut Vec) } } +/// Depth-first iterator over Image leaves under a slice of AST nodes. +/// Yields `(bytes, mime)` borrows in document order; doesn't allocate +/// per yield (only a stack of pending nodes). +struct ImageLeaves<'a> { + stack: Vec<&'a AstNode>, +} + +impl<'a> ImageLeaves<'a> { + fn new(nodes: &'a [AstNode]) -> Self { + let mut stack = Vec::with_capacity(nodes.len()); + stack.extend(nodes.iter().rev()); + Self { stack } + } +} + +impl<'a> Iterator for ImageLeaves<'a> { + type Item = (&'a [u8], &'a str); + fn next(&mut self) -> Option { + while let Some(node) = self.stack.pop() { + match node { + AstNode::Leaf(leaf) => { + if let NodeBody::Image { bytes, mime, .. } = leaf.body() { + return Some((bytes, mime)); + } + } + AstNode::Branch { children, .. } => { + self.stack.extend(children.iter().rev()); + } + } + } + None + } +} + +/// Iterator over `(start, end)` token-offset pairs for each +/// `VISION_START..VISION_END` block in a token slice. Panics on an +/// unmatched VISION_START — that's an upstream tokenization bug +/// worth a loud failure. +fn vision_blocks(cached: &[u32]) -> impl Iterator + '_ { + let mut cur = 0; + std::iter::from_fn(move || { + while cur < cached.len() { + if cached[cur] == tokenizer::VISION_START { + let start = cur; + let end_rel = cached[cur..].iter() + .position(|&t| t == tokenizer::VISION_END) + .unwrap_or_else(|| panic!( + "unmatched VISION_START at offset {} in cached branch", + start)); + let end = cur + end_rel + 1; + cur = end; + return Some((start, end)); + } + cur += 1; + } + None + }) +} + +/// For a Branch whose `token_ids` are cached and may contain inlined +/// vision blocks (`VISION_START + IMAGE_PAD*N + VISION_END`), recover +/// the matching image bytes/mime from the children and emit one +/// `WireImage` per vision block with the absolute pad offsets in the +/// parent token stream. +/// +/// The cache stores tokens but not image payloads; the AST stores +/// image payloads in the children but not their post-cache positions. +/// Pair them by zipping the two iterators; mismatched counts panic +/// loudly because that's an AST/cache invariant violation that +/// would otherwise mis-pair images on the wire. +fn pair_cached_images( + cached: &[u32], + children: &[AstNode], + base_offset: u32, + images: &mut Vec, +) { + let mut blocks = vision_blocks(cached); + let mut leaves = ImageLeaves::new(children); + loop { + match (blocks.next(), leaves.next()) { + (Some((s, e)), Some((bytes, mime))) => images.push(WireImage { + bytes: bytes.to_vec(), + mime: mime.to_string(), + pad_start: base_offset + s as u32, + pad_end: base_offset + e as u32, + }), + (None, None) => break, + (Some(_), None) => panic!( + "cached branch has more vision blocks than image children"), + (None, Some(_)) => panic!( + "cached branch has fewer vision blocks than image children"), + } + } +} + pub fn memory_key(node: &AstNode) -> Option<&str> { match node { AstNode::Leaf(leaf) => match leaf.body() { @@ -1224,8 +1362,13 @@ impl ContextState { } _ => buf.extend_from_slice(leaf.token_ids()), }, - AstNode::Branch { token_ids: Some(cached), .. } => { + AstNode::Branch { token_ids: Some(cached), children, .. } => { + // Same fix as wire_into's cached arm: the cache + // holds vision tokens but not the matching bytes, + // so walk children to recover them. + let base = buf.len() as u32; buf.extend_from_slice(cached); + pair_cached_images(cached, children, base, images); } AstNode::Branch { role, children, token_ids: None, .. } => { buf.push(tokenizer::IM_START); diff --git a/src/subconscious/learn.rs b/src/subconscious/learn.rs index feb209c..129e26b 100644 --- a/src/subconscious/learn.rs +++ b/src/subconscious/learn.rs @@ -240,25 +240,23 @@ pub async fn score_memories( vec![0.0; baseline.len()] } }; - // Write this memory's scores to the live AST nodes + // Write this memory's scores to the live AST nodes via the + // focused setter — keeps the AST mutation surface narrow. { let mut ctx = agent.context.lock().await; let mut set_count = 0; for (resp_idx, &idx) in response_indices.iter().enumerate() { - if idx >= ctx.conversation().len() { continue; } - let node = &mut ctx.conversation_mut()[idx]; - if let AstNode::Branch { - role: Role::Assistant, memory_scores, .. - } = node { - if let Some(&score) = row.get(resp_idx) { - if score > 0.01 { - memory_scores.insert(key.clone(), score); - set_count += 1; - } else { - memory_scores.remove(key.as_str()); - } - } + let Some(&score) = row.get(resp_idx) else { continue }; + let normalized = if score > 0.01 { Some(score) } else { None }; + ctx.set_branch_memory_score( + crate::agent::context::Section::Conversation, + idx, + &key, + normalized, + ); + if normalized.is_some() { + set_count += 1; } }