forked from kent/consciousness
context: tighten the Branch token-cache invariant
Two pieces around the cache that landed when Branch nodes started
holding `token_ids: Some(server_authoritative_stream)`:
1. wire_into / wire_chunks now pair cached vision blocks with their
child Image leaves. Previously the cached-branch arm spliced the
cache verbatim and didn't recurse for images, so a Branch whose
cache contained `VISION_START..VISION_END` blocks would emit those
tokens with no matching `WireImage` push — leading to a panic
downstream when `pair_images_to_ranges` tried to attach the
missing image. New `pair_cached_images` walks the children
depth-first for image leaves and zips them against
`vision_blocks(cache)` to emit correctly-offset entries; mismatched
counts panic loudly because that's an AST/cache invariant
violation that would otherwise mis-pair on the wire.
2. `conversation_mut() -> &mut Vec<AstNode>` was the one public
escape hatch that let callers reach into a Branch's children and
mutate them without invalidating the cached token stream. Removed
in favor of a focused `set_branch_memory_score(section, index,
key, score)` for the only legitimate use we had today (the
full-matrix scorer writing per-memory divergence onto the
Assistant Branch). Updated the lone caller in subconscious/learn.
Documented the invariants explicitly on `ContextState`: every
`Leaf.token_ids` matches `body.compute_token_ids()`, and every
`Branch { token_ids: Some(_) }` is a faithful walk of its children.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
006b99bdac
commit
c2433c1773
2 changed files with 160 additions and 19 deletions
|
|
@ -154,8 +154,19 @@ pub enum AstNode {
|
|||
}
|
||||
|
||||
/// The context window: four sections as Vec<AstNode>.
|
||||
/// All mutation goes through ContextState methods to maintain the invariant
|
||||
/// that token_ids on every leaf matches its rendered text.
|
||||
///
|
||||
/// All mutation MUST go through `ContextState`'s public methods. Two
|
||||
/// invariants ride on this:
|
||||
/// 1. Every `Leaf.token_ids` matches its `body.compute_token_ids()`.
|
||||
/// 2. For every `Branch { token_ids: Some(cached), .. }`, the cached
|
||||
/// token stream matches what `wire_into` would produce by walking
|
||||
/// `children` from scratch. Any mutation that touches a Branch's
|
||||
/// children — directly or via a descendant — must clear the
|
||||
/// Branch's `token_ids` so it gets recomputed on next wire-out.
|
||||
///
|
||||
/// The `&mut Vec<AstNode>` escape hatches are intentionally NOT
|
||||
/// exposed; if you find yourself wanting one, add a focused method
|
||||
/// here that maintains the invariants.
|
||||
pub struct ContextState {
|
||||
system: Vec<AstNode>,
|
||||
identity: Vec<AstNode>,
|
||||
|
|
@ -966,7 +977,33 @@ impl ContextState {
|
|||
/// Read-only view of the identity section's nodes.
pub fn identity(&self) -> &[AstNode] { &self.identity }
|
||||
/// Read-only view of the journal section's nodes.
pub fn journal(&self) -> &[AstNode] { &self.journal }
|
||||
/// Read-only view of the conversation section's nodes.
pub fn conversation(&self) -> &[AstNode] { &self.conversation }
|
||||
/// Mutable access to the conversation section's backing `Vec`.
///
/// NOTE(review): this hands out `&mut Vec<AstNode>` directly, so a caller
/// can edit nodes without going through any other `ContextState` method —
/// confirm this accessor is still wanted.
pub fn conversation_mut(&mut self) -> &mut Vec<AstNode> { &mut self.conversation }
|
||||
|
||||
/// Set or clear a single `memory_scores` entry on an Assistant
|
||||
/// Branch. Used by the full-matrix scorer to attribute per-memory
|
||||
/// divergence onto the response. `score = None` removes the key;
|
||||
/// `Some(s)` inserts/overwrites.
|
||||
///
|
||||
/// Doesn't affect the Branch's token cache: `memory_scores` is a
|
||||
/// serialized-but-non-tokenizing annotation. No-op (with a debug
|
||||
/// log) if the index points to a Leaf or a non-Assistant Branch —
|
||||
/// callers are typically iterating on stale indices and we'd
|
||||
/// rather skip than panic.
|
||||
pub fn set_branch_memory_score(
|
||||
&mut self,
|
||||
section: Section,
|
||||
index: usize,
|
||||
key: &str,
|
||||
score: Option<f64>,
|
||||
) {
|
||||
let nodes = self.section_mut(section);
|
||||
let Some(node) = nodes.get_mut(index) else { return };
|
||||
let AstNode::Branch { role: Role::Assistant, memory_scores, .. } = node
|
||||
else { return };
|
||||
match score {
|
||||
Some(s) => { memory_scores.insert(key.to_string(), s); }
|
||||
None => { memory_scores.remove(key); }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sections(&self) -> [&Vec<AstNode>; 4] {
|
||||
[&self.system, &self.identity, &self.journal, &self.conversation]
|
||||
|
|
@ -1051,8 +1088,14 @@ fn wire_into(node: &AstNode, tokens: &mut Vec<u32>, images: &mut Vec<WireImage>)
|
|||
}
|
||||
_ => tokens.extend_from_slice(leaf.token_ids()),
|
||||
},
|
||||
AstNode::Branch { token_ids: Some(cached), .. } => {
|
||||
AstNode::Branch { token_ids: Some(cached), children, .. } => {
|
||||
// Cached branches still need their image children paired
|
||||
// up with the vision-block ranges embedded in the cached
|
||||
// token stream — the cache captures vision tokens but not
|
||||
// the matching bytes/mime.
|
||||
let base = tokens.len() as u32;
|
||||
tokens.extend_from_slice(cached);
|
||||
pair_cached_images(cached, children, base, images);
|
||||
}
|
||||
AstNode::Branch { role, children, token_ids: None, .. } => {
|
||||
tokens.push(tokenizer::IM_START);
|
||||
|
|
@ -1066,6 +1109,101 @@ fn wire_into(node: &AstNode, tokens: &mut Vec<u32>, images: &mut Vec<WireImage>)
|
|||
}
|
||||
}
|
||||
|
||||
/// Depth-first iterator over Image leaves under a slice of AST nodes.
|
||||
/// Yields `(bytes, mime)` borrows in document order; doesn't allocate
|
||||
/// per yield (only a stack of pending nodes).
|
||||
struct ImageLeaves<'a> {
|
||||
stack: Vec<&'a AstNode>,
|
||||
}
|
||||
|
||||
impl<'a> ImageLeaves<'a> {
|
||||
fn new(nodes: &'a [AstNode]) -> Self {
|
||||
let mut stack = Vec::with_capacity(nodes.len());
|
||||
stack.extend(nodes.iter().rev());
|
||||
Self { stack }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for ImageLeaves<'a> {
|
||||
type Item = (&'a [u8], &'a str);
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
while let Some(node) = self.stack.pop() {
|
||||
match node {
|
||||
AstNode::Leaf(leaf) => {
|
||||
if let NodeBody::Image { bytes, mime, .. } = leaf.body() {
|
||||
return Some((bytes, mime));
|
||||
}
|
||||
}
|
||||
AstNode::Branch { children, .. } => {
|
||||
self.stack.extend(children.iter().rev());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterator over `(start, end)` token-offset pairs for each
|
||||
/// `VISION_START..VISION_END` block in a token slice. Panics on an
|
||||
/// unmatched VISION_START — that's an upstream tokenization bug
|
||||
/// worth a loud failure.
|
||||
fn vision_blocks(cached: &[u32]) -> impl Iterator<Item = (usize, usize)> + '_ {
|
||||
let mut cur = 0;
|
||||
std::iter::from_fn(move || {
|
||||
while cur < cached.len() {
|
||||
if cached[cur] == tokenizer::VISION_START {
|
||||
let start = cur;
|
||||
let end_rel = cached[cur..].iter()
|
||||
.position(|&t| t == tokenizer::VISION_END)
|
||||
.unwrap_or_else(|| panic!(
|
||||
"unmatched VISION_START at offset {} in cached branch",
|
||||
start));
|
||||
let end = cur + end_rel + 1;
|
||||
cur = end;
|
||||
return Some((start, end));
|
||||
}
|
||||
cur += 1;
|
||||
}
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
/// For a Branch whose `token_ids` are cached and may contain inlined
|
||||
/// vision blocks (`VISION_START + IMAGE_PAD*N + VISION_END`), recover
|
||||
/// the matching image bytes/mime from the children and emit one
|
||||
/// `WireImage` per vision block with the absolute pad offsets in the
|
||||
/// parent token stream.
|
||||
///
|
||||
/// The cache stores tokens but not image payloads; the AST stores
|
||||
/// image payloads in the children but not their post-cache positions.
|
||||
/// Pair them by zipping the two iterators; mismatched counts panic
|
||||
/// loudly because that's an AST/cache invariant violation that
|
||||
/// would otherwise mis-pair images on the wire.
|
||||
fn pair_cached_images(
|
||||
cached: &[u32],
|
||||
children: &[AstNode],
|
||||
base_offset: u32,
|
||||
images: &mut Vec<WireImage>,
|
||||
) {
|
||||
let mut blocks = vision_blocks(cached);
|
||||
let mut leaves = ImageLeaves::new(children);
|
||||
loop {
|
||||
match (blocks.next(), leaves.next()) {
|
||||
(Some((s, e)), Some((bytes, mime))) => images.push(WireImage {
|
||||
bytes: bytes.to_vec(),
|
||||
mime: mime.to_string(),
|
||||
pad_start: base_offset + s as u32,
|
||||
pad_end: base_offset + e as u32,
|
||||
}),
|
||||
(None, None) => break,
|
||||
(Some(_), None) => panic!(
|
||||
"cached branch has more vision blocks than image children"),
|
||||
(None, Some(_)) => panic!(
|
||||
"cached branch has fewer vision blocks than image children"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn memory_key(node: &AstNode) -> Option<&str> {
|
||||
match node {
|
||||
AstNode::Leaf(leaf) => match leaf.body() {
|
||||
|
|
@ -1224,8 +1362,13 @@ impl ContextState {
|
|||
}
|
||||
_ => buf.extend_from_slice(leaf.token_ids()),
|
||||
},
|
||||
AstNode::Branch { token_ids: Some(cached), .. } => {
|
||||
AstNode::Branch { token_ids: Some(cached), children, .. } => {
|
||||
// Same fix as wire_into's cached arm: the cache
|
||||
// holds vision tokens but not the matching bytes,
|
||||
// so walk children to recover them.
|
||||
let base = buf.len() as u32;
|
||||
buf.extend_from_slice(cached);
|
||||
pair_cached_images(cached, children, base, images);
|
||||
}
|
||||
AstNode::Branch { role, children, token_ids: None, .. } => {
|
||||
buf.push(tokenizer::IM_START);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue