context: tighten the Branch token-cache invariant

Two pieces around the cache that landed when Branch nodes started
holding `token_ids: Some(server_authoritative_stream)`:

1. wire_into / wire_chunks now pair cached vision blocks with their
   child Image leaves. Previously the cached-branch arm spliced the
   cache verbatim and didn't recurse for images, so a Branch whose
   cache contained `VISION_START..VISION_END` blocks would emit those
   tokens with no matching `WireImage` push — leading to a panic
   downstream when `pair_images_to_ranges` tried to attach the
   missing image. New `pair_cached_images` walks the children
   depth-first for image leaves and zips them against
   `vision_blocks(cache)` to emit correctly-offset entries; mismatched
   counts panic loudly because that's an AST/cache invariant
   violation that would otherwise mis-pair on the wire.

2. `conversation_mut() -> &mut Vec<AstNode>` was the one public
   escape hatch that let callers reach into a Branch's children and
   mutate them without invalidating the cached token stream. Removed
   in favor of a focused `set_branch_memory_score(section, index,
   key, score)` for the only legitimate use we had today (the
   full-matrix scorer writing per-memory divergence onto the
   Assistant Branch). Updated the lone caller in subconscious/learn.

Documented the invariants explicitly on `ContextState`: every
`Leaf.token_ids` matches `body.compute_token_ids()`, and every
`Branch { token_ids: Some(_) }` is a faithful walk of its children.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-24 23:15:55 -04:00
commit c2433c1773
2 changed files with 160 additions and 19 deletions

View file

@ -154,8 +154,19 @@ pub enum AstNode {
}
/// The context window: four sections as Vec<AstNode>.
/// All mutation goes through ContextState methods to maintain the invariant
/// that token_ids on every leaf matches its rendered text.
///
/// All mutation MUST go through `ContextState`'s public methods. Two
/// invariants ride on this:
/// 1. Every `Leaf.token_ids` matches its `body.compute_token_ids()`.
/// 2. For every `Branch { token_ids: Some(cached), .. }`, the cached
/// token stream matches what `wire_into` would produce by walking
/// `children` from scratch. Any mutation that touches a Branch's
/// children — directly or via a descendant — must clear the
/// Branch's `token_ids` so it gets recomputed on next wire-out.
///
/// The `&mut Vec<AstNode>` escape hatches are intentionally NOT
/// exposed; if you find yourself wanting one, add a focused method
/// here that maintains the invariants.
pub struct ContextState {
system: Vec<AstNode>,
identity: Vec<AstNode>,
@ -966,7 +977,33 @@ impl ContextState {
pub fn identity(&self) -> &[AstNode] { &self.identity }
pub fn journal(&self) -> &[AstNode] { &self.journal }
pub fn conversation(&self) -> &[AstNode] { &self.conversation }
pub fn conversation_mut(&mut self) -> &mut Vec<AstNode> { &mut self.conversation }
/// Set or clear a single `memory_scores` entry on an Assistant
/// Branch. Used by the full-matrix scorer to attribute per-memory
/// divergence onto the response. `score = None` removes the key;
/// `Some(s)` inserts/overwrites.
///
/// Doesn't affect the Branch's token cache: `memory_scores` is a
/// serialized-but-non-tokenizing annotation. No-op (with a debug
/// log) if the index points to a Leaf or a non-Assistant Branch —
/// callers are typically iterating on stale indices and we'd
/// rather skip than panic.
pub fn set_branch_memory_score(
&mut self,
section: Section,
index: usize,
key: &str,
score: Option<f64>,
) {
let nodes = self.section_mut(section);
let Some(node) = nodes.get_mut(index) else { return };
let AstNode::Branch { role: Role::Assistant, memory_scores, .. } = node
else { return };
match score {
Some(s) => { memory_scores.insert(key.to_string(), s); }
None => { memory_scores.remove(key); }
}
}
pub fn sections(&self) -> [&Vec<AstNode>; 4] {
[&self.system, &self.identity, &self.journal, &self.conversation]
@ -1051,8 +1088,14 @@ fn wire_into(node: &AstNode, tokens: &mut Vec<u32>, images: &mut Vec<WireImage>)
}
_ => tokens.extend_from_slice(leaf.token_ids()),
},
AstNode::Branch { token_ids: Some(cached), .. } => {
AstNode::Branch { token_ids: Some(cached), children, .. } => {
// Cached branches still need their image children paired
// up with the vision-block ranges embedded in the cached
// token stream — the cache captures vision tokens but not
// the matching bytes/mime.
let base = tokens.len() as u32;
tokens.extend_from_slice(cached);
pair_cached_images(cached, children, base, images);
}
AstNode::Branch { role, children, token_ids: None, .. } => {
tokens.push(tokenizer::IM_START);
@ -1066,6 +1109,101 @@ fn wire_into(node: &AstNode, tokens: &mut Vec<u32>, images: &mut Vec<WireImage>)
}
}
/// Depth-first iterator over Image leaves under a slice of AST nodes.
/// Yields `(bytes, mime)` borrows in document order; doesn't allocate
/// per yield (only a stack of pending nodes).
struct ImageLeaves<'a> {
stack: Vec<&'a AstNode>,
}
impl<'a> ImageLeaves<'a> {
fn new(nodes: &'a [AstNode]) -> Self {
let mut stack = Vec::with_capacity(nodes.len());
stack.extend(nodes.iter().rev());
Self { stack }
}
}
impl<'a> Iterator for ImageLeaves<'a> {
type Item = (&'a [u8], &'a str);
fn next(&mut self) -> Option<Self::Item> {
while let Some(node) = self.stack.pop() {
match node {
AstNode::Leaf(leaf) => {
if let NodeBody::Image { bytes, mime, .. } = leaf.body() {
return Some((bytes, mime));
}
}
AstNode::Branch { children, .. } => {
self.stack.extend(children.iter().rev());
}
}
}
None
}
}
/// Iterator over `(start, end)` token-offset pairs for each
/// `VISION_START..VISION_END` block in a token slice. Panics on an
/// unmatched VISION_START — that's an upstream tokenization bug
/// worth a loud failure.
fn vision_blocks(cached: &[u32]) -> impl Iterator<Item = (usize, usize)> + '_ {
let mut cur = 0;
std::iter::from_fn(move || {
while cur < cached.len() {
if cached[cur] == tokenizer::VISION_START {
let start = cur;
let end_rel = cached[cur..].iter()
.position(|&t| t == tokenizer::VISION_END)
.unwrap_or_else(|| panic!(
"unmatched VISION_START at offset {} in cached branch",
start));
let end = cur + end_rel + 1;
cur = end;
return Some((start, end));
}
cur += 1;
}
None
})
}
/// For a Branch whose `token_ids` are cached and may contain inlined
/// vision blocks (`VISION_START + IMAGE_PAD*N + VISION_END`), recover
/// the matching image bytes/mime from the children and emit one
/// `WireImage` per vision block with the absolute pad offsets in the
/// parent token stream.
///
/// The cache stores tokens but not image payloads; the AST stores
/// image payloads in the children but not their post-cache positions.
/// Pair them by zipping the two iterators; mismatched counts panic
/// loudly because that's an AST/cache invariant violation that
/// would otherwise mis-pair images on the wire.
fn pair_cached_images(
cached: &[u32],
children: &[AstNode],
base_offset: u32,
images: &mut Vec<WireImage>,
) {
let mut blocks = vision_blocks(cached);
let mut leaves = ImageLeaves::new(children);
loop {
match (blocks.next(), leaves.next()) {
(Some((s, e)), Some((bytes, mime))) => images.push(WireImage {
bytes: bytes.to_vec(),
mime: mime.to_string(),
pad_start: base_offset + s as u32,
pad_end: base_offset + e as u32,
}),
(None, None) => break,
(Some(_), None) => panic!(
"cached branch has more vision blocks than image children"),
(None, Some(_)) => panic!(
"cached branch has fewer vision blocks than image children"),
}
}
}
pub fn memory_key(node: &AstNode) -> Option<&str> {
match node {
AstNode::Leaf(leaf) => match leaf.body() {
@ -1224,8 +1362,13 @@ impl ContextState {
}
_ => buf.extend_from_slice(leaf.token_ids()),
},
AstNode::Branch { token_ids: Some(cached), .. } => {
AstNode::Branch { token_ids: Some(cached), children, .. } => {
// Same fix as wire_into's cached arm: the cache
// holds vision tokens but not the matching bytes,
// so walk children to recover them.
let base = buf.len() as u32;
buf.extend_from_slice(cached);
pair_cached_images(cached, children, base, images);
}
AstNode::Branch { role, children, token_ids: None, .. } => {
buf.push(tokenizer::IM_START);