forked from kent/consciousness
context: tighten the Branch token-cache invariant
Two pieces around the cache that landed when Branch nodes started
holding `token_ids: Some(server_authoritative_stream)`:
1. wire_into / wire_chunks now pair cached vision blocks with their
child Image leaves. Previously the cached-branch arm spliced the
cache verbatim and didn't recurse for images, so a Branch whose
cache contained `VISION_START..VISION_END` blocks would emit those
tokens with no matching `WireImage` push — leading to a panic
downstream when `pair_images_to_ranges` tried to attach the
missing image. New `pair_cached_images` walks the children
depth-first for image leaves and zips them against
`vision_blocks(cache)` to emit correctly-offset entries; mismatched
counts panic loudly because that's an AST/cache invariant
violation that would otherwise mis-pair on the wire.
2. `conversation_mut() -> &mut Vec<AstNode>` was the one public
escape hatch that let callers reach into a Branch's children and
mutate them without invalidating the cached token stream. Removed
in favor of a focused `set_branch_memory_score(section, index,
key, score)` for the only legitimate use we had today (the
full-matrix scorer writing per-memory divergence onto the
Assistant Branch). Updated the lone caller in subconscious/learn.
Documented the invariants explicitly on `ContextState`: every
`Leaf.token_ids` matches `body.compute_token_ids()`, and every
`Branch { token_ids: Some(_) }` is a faithful walk of its children.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
006b99bdac
commit
c2433c1773
2 changed files with 160 additions and 19 deletions
|
|
@ -154,8 +154,19 @@ pub enum AstNode {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The context window: four sections as Vec<AstNode>.
|
/// The context window: four sections as Vec<AstNode>.
|
||||||
/// All mutation goes through ContextState methods to maintain the invariant
|
///
|
||||||
/// that token_ids on every leaf matches its rendered text.
|
/// All mutation MUST go through `ContextState`'s public methods. Two
|
||||||
|
/// invariants ride on this:
|
||||||
|
/// 1. Every `Leaf.token_ids` matches its `body.compute_token_ids()`.
|
||||||
|
/// 2. For every `Branch { token_ids: Some(cached), .. }`, the cached
|
||||||
|
/// token stream matches what `wire_into` would produce by walking
|
||||||
|
/// `children` from scratch. Any mutation that touches a Branch's
|
||||||
|
/// children — directly or via a descendant — must clear the
|
||||||
|
/// Branch's `token_ids` so it gets recomputed on next wire-out.
|
||||||
|
///
|
||||||
|
/// The `&mut Vec<AstNode>` escape hatches are intentionally NOT
|
||||||
|
/// exposed; if you find yourself wanting one, add a focused method
|
||||||
|
/// here that maintains the invariants.
|
||||||
pub struct ContextState {
|
pub struct ContextState {
|
||||||
system: Vec<AstNode>,
|
system: Vec<AstNode>,
|
||||||
identity: Vec<AstNode>,
|
identity: Vec<AstNode>,
|
||||||
|
|
@ -966,7 +977,33 @@ impl ContextState {
|
||||||
pub fn identity(&self) -> &[AstNode] { &self.identity }
|
pub fn identity(&self) -> &[AstNode] { &self.identity }
|
||||||
pub fn journal(&self) -> &[AstNode] { &self.journal }
|
pub fn journal(&self) -> &[AstNode] { &self.journal }
|
||||||
pub fn conversation(&self) -> &[AstNode] { &self.conversation }
|
pub fn conversation(&self) -> &[AstNode] { &self.conversation }
|
||||||
pub fn conversation_mut(&mut self) -> &mut Vec<AstNode> { &mut self.conversation }
|
|
||||||
|
/// Set or clear a single `memory_scores` entry on an Assistant
|
||||||
|
/// Branch. Used by the full-matrix scorer to attribute per-memory
|
||||||
|
/// divergence onto the response. `score = None` removes the key;
|
||||||
|
/// `Some(s)` inserts/overwrites.
|
||||||
|
///
|
||||||
|
/// Doesn't affect the Branch's token cache: `memory_scores` is a
|
||||||
|
/// serialized-but-non-tokenizing annotation. No-op (with a debug
|
||||||
|
/// log) if the index points to a Leaf or a non-Assistant Branch —
|
||||||
|
/// callers are typically iterating on stale indices and we'd
|
||||||
|
/// rather skip than panic.
|
||||||
|
pub fn set_branch_memory_score(
|
||||||
|
&mut self,
|
||||||
|
section: Section,
|
||||||
|
index: usize,
|
||||||
|
key: &str,
|
||||||
|
score: Option<f64>,
|
||||||
|
) {
|
||||||
|
let nodes = self.section_mut(section);
|
||||||
|
let Some(node) = nodes.get_mut(index) else { return };
|
||||||
|
let AstNode::Branch { role: Role::Assistant, memory_scores, .. } = node
|
||||||
|
else { return };
|
||||||
|
match score {
|
||||||
|
Some(s) => { memory_scores.insert(key.to_string(), s); }
|
||||||
|
None => { memory_scores.remove(key); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn sections(&self) -> [&Vec<AstNode>; 4] {
|
pub fn sections(&self) -> [&Vec<AstNode>; 4] {
|
||||||
[&self.system, &self.identity, &self.journal, &self.conversation]
|
[&self.system, &self.identity, &self.journal, &self.conversation]
|
||||||
|
|
@ -1051,8 +1088,14 @@ fn wire_into(node: &AstNode, tokens: &mut Vec<u32>, images: &mut Vec<WireImage>)
|
||||||
}
|
}
|
||||||
_ => tokens.extend_from_slice(leaf.token_ids()),
|
_ => tokens.extend_from_slice(leaf.token_ids()),
|
||||||
},
|
},
|
||||||
AstNode::Branch { token_ids: Some(cached), .. } => {
|
AstNode::Branch { token_ids: Some(cached), children, .. } => {
|
||||||
|
// Cached branches still need their image children paired
|
||||||
|
// up with the vision-block ranges embedded in the cached
|
||||||
|
// token stream — the cache captures vision tokens but not
|
||||||
|
// the matching bytes/mime.
|
||||||
|
let base = tokens.len() as u32;
|
||||||
tokens.extend_from_slice(cached);
|
tokens.extend_from_slice(cached);
|
||||||
|
pair_cached_images(cached, children, base, images);
|
||||||
}
|
}
|
||||||
AstNode::Branch { role, children, token_ids: None, .. } => {
|
AstNode::Branch { role, children, token_ids: None, .. } => {
|
||||||
tokens.push(tokenizer::IM_START);
|
tokens.push(tokenizer::IM_START);
|
||||||
|
|
@ -1066,6 +1109,101 @@ fn wire_into(node: &AstNode, tokens: &mut Vec<u32>, images: &mut Vec<WireImage>)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Depth-first iterator over Image leaves under a slice of AST nodes.
|
||||||
|
/// Yields `(bytes, mime)` borrows in document order; doesn't allocate
|
||||||
|
/// per yield (only a stack of pending nodes).
|
||||||
|
struct ImageLeaves<'a> {
|
||||||
|
stack: Vec<&'a AstNode>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> ImageLeaves<'a> {
|
||||||
|
fn new(nodes: &'a [AstNode]) -> Self {
|
||||||
|
let mut stack = Vec::with_capacity(nodes.len());
|
||||||
|
stack.extend(nodes.iter().rev());
|
||||||
|
Self { stack }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Iterator for ImageLeaves<'a> {
|
||||||
|
type Item = (&'a [u8], &'a str);
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
while let Some(node) = self.stack.pop() {
|
||||||
|
match node {
|
||||||
|
AstNode::Leaf(leaf) => {
|
||||||
|
if let NodeBody::Image { bytes, mime, .. } = leaf.body() {
|
||||||
|
return Some((bytes, mime));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
AstNode::Branch { children, .. } => {
|
||||||
|
self.stack.extend(children.iter().rev());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Iterator over `(start, end)` token-offset pairs for each
|
||||||
|
/// `VISION_START..VISION_END` block in a token slice. Panics on an
|
||||||
|
/// unmatched VISION_START — that's an upstream tokenization bug
|
||||||
|
/// worth a loud failure.
|
||||||
|
fn vision_blocks(cached: &[u32]) -> impl Iterator<Item = (usize, usize)> + '_ {
|
||||||
|
let mut cur = 0;
|
||||||
|
std::iter::from_fn(move || {
|
||||||
|
while cur < cached.len() {
|
||||||
|
if cached[cur] == tokenizer::VISION_START {
|
||||||
|
let start = cur;
|
||||||
|
let end_rel = cached[cur..].iter()
|
||||||
|
.position(|&t| t == tokenizer::VISION_END)
|
||||||
|
.unwrap_or_else(|| panic!(
|
||||||
|
"unmatched VISION_START at offset {} in cached branch",
|
||||||
|
start));
|
||||||
|
let end = cur + end_rel + 1;
|
||||||
|
cur = end;
|
||||||
|
return Some((start, end));
|
||||||
|
}
|
||||||
|
cur += 1;
|
||||||
|
}
|
||||||
|
None
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// For a Branch whose `token_ids` are cached and may contain inlined
|
||||||
|
/// vision blocks (`VISION_START + IMAGE_PAD*N + VISION_END`), recover
|
||||||
|
/// the matching image bytes/mime from the children and emit one
|
||||||
|
/// `WireImage` per vision block with the absolute pad offsets in the
|
||||||
|
/// parent token stream.
|
||||||
|
///
|
||||||
|
/// The cache stores tokens but not image payloads; the AST stores
|
||||||
|
/// image payloads in the children but not their post-cache positions.
|
||||||
|
/// Pair them by zipping the two iterators; mismatched counts panic
|
||||||
|
/// loudly because that's an AST/cache invariant violation that
|
||||||
|
/// would otherwise mis-pair images on the wire.
|
||||||
|
fn pair_cached_images(
|
||||||
|
cached: &[u32],
|
||||||
|
children: &[AstNode],
|
||||||
|
base_offset: u32,
|
||||||
|
images: &mut Vec<WireImage>,
|
||||||
|
) {
|
||||||
|
let mut blocks = vision_blocks(cached);
|
||||||
|
let mut leaves = ImageLeaves::new(children);
|
||||||
|
loop {
|
||||||
|
match (blocks.next(), leaves.next()) {
|
||||||
|
(Some((s, e)), Some((bytes, mime))) => images.push(WireImage {
|
||||||
|
bytes: bytes.to_vec(),
|
||||||
|
mime: mime.to_string(),
|
||||||
|
pad_start: base_offset + s as u32,
|
||||||
|
pad_end: base_offset + e as u32,
|
||||||
|
}),
|
||||||
|
(None, None) => break,
|
||||||
|
(Some(_), None) => panic!(
|
||||||
|
"cached branch has more vision blocks than image children"),
|
||||||
|
(None, Some(_)) => panic!(
|
||||||
|
"cached branch has fewer vision blocks than image children"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn memory_key(node: &AstNode) -> Option<&str> {
|
pub fn memory_key(node: &AstNode) -> Option<&str> {
|
||||||
match node {
|
match node {
|
||||||
AstNode::Leaf(leaf) => match leaf.body() {
|
AstNode::Leaf(leaf) => match leaf.body() {
|
||||||
|
|
@ -1224,8 +1362,13 @@ impl ContextState {
|
||||||
}
|
}
|
||||||
_ => buf.extend_from_slice(leaf.token_ids()),
|
_ => buf.extend_from_slice(leaf.token_ids()),
|
||||||
},
|
},
|
||||||
AstNode::Branch { token_ids: Some(cached), .. } => {
|
AstNode::Branch { token_ids: Some(cached), children, .. } => {
|
||||||
|
// Same fix as wire_into's cached arm: the cache
|
||||||
|
// holds vision tokens but not the matching bytes,
|
||||||
|
// so walk children to recover them.
|
||||||
|
let base = buf.len() as u32;
|
||||||
buf.extend_from_slice(cached);
|
buf.extend_from_slice(cached);
|
||||||
|
pair_cached_images(cached, children, base, images);
|
||||||
}
|
}
|
||||||
AstNode::Branch { role, children, token_ids: None, .. } => {
|
AstNode::Branch { role, children, token_ids: None, .. } => {
|
||||||
buf.push(tokenizer::IM_START);
|
buf.push(tokenizer::IM_START);
|
||||||
|
|
|
||||||
|
|
@ -240,25 +240,23 @@ pub async fn score_memories(
|
||||||
vec![0.0; baseline.len()]
|
vec![0.0; baseline.len()]
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
// Write this memory's scores to the live AST nodes
|
// Write this memory's scores to the live AST nodes via the
|
||||||
|
// focused setter — keeps the AST mutation surface narrow.
|
||||||
{
|
{
|
||||||
let mut ctx = agent.context.lock().await;
|
let mut ctx = agent.context.lock().await;
|
||||||
let mut set_count = 0;
|
let mut set_count = 0;
|
||||||
|
|
||||||
for (resp_idx, &idx) in response_indices.iter().enumerate() {
|
for (resp_idx, &idx) in response_indices.iter().enumerate() {
|
||||||
if idx >= ctx.conversation().len() { continue; }
|
let Some(&score) = row.get(resp_idx) else { continue };
|
||||||
let node = &mut ctx.conversation_mut()[idx];
|
let normalized = if score > 0.01 { Some(score) } else { None };
|
||||||
if let AstNode::Branch {
|
ctx.set_branch_memory_score(
|
||||||
role: Role::Assistant, memory_scores, ..
|
crate::agent::context::Section::Conversation,
|
||||||
} = node {
|
idx,
|
||||||
if let Some(&score) = row.get(resp_idx) {
|
&key,
|
||||||
if score > 0.01 {
|
normalized,
|
||||||
memory_scores.insert(key.clone(), score);
|
);
|
||||||
set_count += 1;
|
if normalized.is_some() {
|
||||||
} else {
|
set_count += 1;
|
||||||
memory_scores.remove(key.as_str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue