Fast startup: only retokenize tail of conversation log

restore_from_log reads the full log but walks backwards from the tail,
retokenizing each node as it goes, and stops once the conversation token
budget is exhausted. Only the nodes that fit are pushed into the context.

Added AstNode::retokenize() — recomputes token_ids on all leaves
after deserialization (serde skip means they're empty).

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-09 13:06:19 -04:00
parent 6ec0e1c766
commit 7da3efc5df
2 changed files with 40 additions and 5 deletions

View file

@@ -296,6 +296,23 @@ impl AstNode {
// -- Builder -------------------------------------------------------------- // -- Builder --------------------------------------------------------------
/// Recompute `token_ids` on every leaf of this subtree.
///
/// Needed after deserialization: `token_ids` is skipped by serde, so
/// freshly loaded nodes carry empty token vectors. Consumes the node and
/// returns it with tokens filled in. Leaves whose body is not
/// prompt-visible keep an empty token list; branches recurse into all
/// children.
pub fn retokenize(self) -> Self {
    match self {
        Self::Leaf(mut leaf) => {
            // Only prompt-visible bodies are rendered and encoded;
            // everything else contributes zero tokens.
            leaf.token_ids = if leaf.body.is_prompt_visible() {
                tokenizer::encode(&leaf.body.render())
            } else {
                Vec::new()
            };
            Self::Leaf(leaf)
        }
        Self::Branch { role, children } => {
            // Rebuild the branch with each child retokenized in turn.
            let children = children.into_iter().map(Self::retokenize).collect();
            Self::Branch { role, children }
        }
    }
}
pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self { pub fn with_timestamp(mut self, ts: DateTime<Utc>) -> Self {
match &mut self { match &mut self {
Self::Leaf(leaf) => leaf.timestamp = Some(ts), Self::Leaf(leaf) => leaf.timestamp = Some(ts),

View file

@@ -568,7 +568,7 @@ impl Agent {
} }
pub async fn restore_from_log(&self) -> bool { pub async fn restore_from_log(&self) -> bool {
let nodes = { let all_nodes = {
let ctx = self.context.lock().await; let ctx = self.context.lock().await;
match &ctx.conversation_log { match &ctx.conversation_log {
Some(log) => match log.read_nodes(64 * 1024 * 1024) { Some(log) => match log.read_nodes(64 * 1024 * 1024) {
@@ -579,17 +579,35 @@
} }
}; };
// Walk backwards from the tail, retokenize, stop at budget
let budget = context::context_budget_tokens();
let fixed = {
let ctx = self.context.lock().await;
ctx.system().iter().chain(ctx.identity().iter())
.map(|n| n.tokens()).sum::<usize>()
};
let conv_budget = budget.saturating_sub(fixed);
let mut kept = Vec::new();
let mut total = 0;
for node in all_nodes.into_iter().rev() {
let node = node.retokenize();
let tok = node.tokens();
if total + tok > conv_budget && !kept.is_empty() { break; }
total += tok;
kept.push(node);
}
kept.reverse();
{ {
let mut ctx = self.context.lock().await; let mut ctx = self.context.lock().await;
ctx.clear(Section::Conversation); ctx.clear(Section::Conversation);
// Push without logging — these are already in the log for node in kept {
for node in nodes {
ctx.push_no_log(Section::Conversation, node); ctx.push_no_log(Section::Conversation, node);
} }
} }
self.compact().await; self.compact().await;
let mut st = self.state.lock().await; self.state.lock().await.last_prompt_tokens = self.context.lock().await.tokens() as u32;
st.last_prompt_tokens = self.context.lock().await.tokens() as u32;
true true
} }