diff --git a/src/agent/context.rs b/src/agent/context.rs
index 38127d5..948e9f2 100644
--- a/src/agent/context.rs
+++ b/src/agent/context.rs
@@ -934,6 +934,53 @@ pub fn is_memory_node(node: &AstNode) -> bool {
     matches!(node, AstNode::Leaf(leaf) if matches!(leaf.body(), NodeBody::Memory { .. }))
 }
 
+pub fn is_assistant(node: &AstNode) -> bool {
+    matches!(node, AstNode::Branch { role: Role::Assistant, .. })
+}
+
+/// Concatenate the text of a Branch's Leaf children — what the model
+/// actually produced on that turn (Content + Thinking + ToolCall name).
+pub fn render_branch_text(children: &[AstNode]) -> String {
+    children.iter()
+        .filter_map(|c| match c {
+            AstNode::Leaf(leaf) => Some(leaf.body().text().to_string()),
+            _ => None,
+        })
+        .collect::<Vec<String>>()
+        .join("")
+}
+
+/// Render the last `max_msgs` user/assistant branches before `idx` as a
+/// review-friendly string with `[user]` / `[assistant]` markers.
+pub fn render_prior_context(entries: &[AstNode], idx: usize, max_msgs: usize) -> String {
+    let mut picked: Vec<&AstNode> = Vec::with_capacity(max_msgs);
+    for i in (0..idx).rev() {
+        if picked.len() >= max_msgs { break; }
+        if let AstNode::Branch { role, .. } = &entries[i] {
+            if matches!(role, Role::User | Role::Assistant) {
+                picked.push(&entries[i]);
+            }
+        }
+    }
+    picked.reverse();
+
+    let mut out = String::new();
+    for node in picked {
+        if let AstNode::Branch { role, children, .. } = node {
+            let marker = match role {
+                Role::User => "[user]",
+                Role::Assistant => "[assistant]",
+                _ => continue,
+            };
+            out.push_str(marker);
+            out.push('\n');
+            out.push_str(render_branch_text(children).trim());
+            out.push_str("\n\n");
+        }
+    }
+    out.trim_end().to_string()
+}
+
 impl ContextState {
     /// Assemble the prompt in wire form: token stream with a single
     /// `<|image_pad|>` per image (vLLM expands back to N), plus the list
diff --git a/src/subconscious/generate.rs b/src/subconscious/generate.rs
new file mode 100644
index 0000000..44f967a
--- /dev/null
+++ b/src/subconscious/generate.rs
@@ -0,0 +1,46 @@
+// generate.rs — Continuation generation for scoring / comparison flows.
+//
+// Shared by the finetune pipeline (learn.rs) and the compare screen:
+// given a context prefix and a skip predicate, generate what the model
+// would say as the next assistant turn.
+
+use crate::agent::api::{ApiClient, SamplingParams, StreamToken};
+use crate::agent::context::{AstNode, ContextState};
+use crate::agent::tokenizer;
+
+/// Generate an assistant continuation from the context up to `entry_idx`,
+/// with `skip` applied to identity + conversation entries during prompt
+/// assembly. The model is whichever `client` points at — the default
+/// runtime client for memory-ablation alternates, a test-model client
+/// for F7 comparison.
+pub async fn gen_continuation<F>(
+    context: &ContextState,
+    entry_idx: usize,
+    skip: F,
+    client: &ApiClient,
+) -> anyhow::Result<String>
+where F: FnMut(&AstNode) -> bool,
+{
+    let (mut prompt, images, _) = context.wire_prompt(0..entry_idx, skip);
+
+    prompt.push(tokenizer::IM_START);
+    prompt.extend(tokenizer::encode("assistant\n"));
+
+    let sampling = SamplingParams {
+        temperature: 0.6,
+        top_p: 0.95,
+        top_k: 20,
+    };
+    let (mut rx, _guard) = client.stream_completion_mm(&prompt, &images, sampling, Some(-5));
+
+    let mut tokens = Vec::new();
+    while let Some(tok) = rx.recv().await {
+        match tok {
+            StreamToken::Token(id) => tokens.push(id),
+            StreamToken::Done { .. } => break,
+            StreamToken::Error(e) => anyhow::bail!("generation error: {}", e),
+        }
+    }
+
+    Ok(tokenizer::decode(&tokens))
+}
diff --git a/src/subconscious/learn.rs b/src/subconscious/learn.rs
index 26c854b..b7656bf 100644
--- a/src/subconscious/learn.rs
+++ b/src/subconscious/learn.rs
@@ -16,16 +16,13 @@
 
 use crate::agent::api::ApiClient;
 use crate::agent::context::{
-    Ast, AstNode, ContextState, Role, WireImage, is_memory_node, memory_key,
+    Ast, AstNode, ContextState, Role, WireImage,
+    is_assistant, is_memory_node, memory_key, render_branch_text, render_prior_context,
 };
-use crate::agent::tokenizer;
+use crate::subconscious::generate::gen_continuation;
 
 const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
 
-fn is_assistant(node: &AstNode) -> bool {
-    matches!(node, AstNode::Branch { role: Role::Assistant, .. })
-}
-
 // ── Score API ───────────────────────────────────────────────────
 
 #[derive(serde::Deserialize)]
@@ -414,50 +411,6 @@ pub async fn score_finetune(
     Ok(results)
 }
 
-/// Concatenate the text of a Branch's Leaf children — what the model
-/// actually produced on that turn (Content + Thinking + ToolCall name).
-fn render_branch_text(children: &[AstNode]) -> String {
-    children.iter()
-        .filter_map(|c| match c {
-            AstNode::Leaf(leaf) => Some(leaf.body().text().to_string()),
-            _ => None,
-        })
-        .collect::<Vec<String>>()
-        .join("")
-}
-
-/// Render the last `max_msgs` user/assistant branches before `idx` as a
-/// review-friendly string with `[user]` / `[assistant]` markers.
-fn render_prior_context(entries: &[AstNode], idx: usize, max_msgs: usize) -> String {
-    use crate::agent::context::Role;
-    let mut picked: Vec<&AstNode> = Vec::with_capacity(max_msgs);
-    for i in (0..idx).rev() {
-        if picked.len() >= max_msgs { break; }
-        if let AstNode::Branch { role, .. } = &entries[i] {
-            if matches!(role, Role::User | Role::Assistant) {
-                picked.push(&entries[i]);
-            }
-        }
-    }
-    picked.reverse();
-
-    let mut out = String::new();
-    for node in picked {
-        if let AstNode::Branch { role, children, .. } = node {
-            let marker = match role {
-                Role::User => "[user]",
-                Role::Assistant => "[assistant]",
-                _ => continue,
-            };
-            out.push_str(marker);
-            out.push('\n');
-            out.push_str(render_branch_text(children).trim());
-            out.push_str("\n\n");
-        }
-    }
-    out.trim_end().to_string()
-}
-
 /// Enriched finetune candidate with context for review.
 #[derive(Clone, Debug)]
 pub struct FinetuneCandidate {
@@ -556,7 +509,7 @@
         activity.update(
             format!("finetune: generating alternate {}/{}", i + 1, total)
         ).await;
-        match generate_alternate(context, candidate.entry_idx, client).await {
+        match gen_continuation(context, candidate.entry_idx, is_memory_node, client).await {
            Ok(text) => candidate.alternate_text = Some(text),
            Err(e) => dbglog!("[finetune] alternate generation failed: {:#}", e),
        }
@@ -567,42 +520,6 @@
 
     Ok((total, max_divergence))
 }
 
-/// Generate what the model would say without memories for a given entry.
-async fn generate_alternate(
-    context: &ContextState,
-    entry_idx: usize,
-    client: &ApiClient,
-) -> anyhow::Result<String> {
-    use crate::agent::api::{SamplingParams, StreamToken};
-
-    // Build context tokens without memories, up to the response
-    let (mut prompt, images, _) =
-        context.wire_prompt(0..entry_idx, is_memory_node);
-
-    // Add assistant turn start
-    prompt.push(tokenizer::IM_START);
-    prompt.extend(tokenizer::encode("assistant\n"));
-
-    // Generate completion
-    let sampling = SamplingParams {
-        temperature: 0.6,
-        top_p: 0.95,
-        top_k: 20,
-    };
-    let (mut rx, _guard) = client.stream_completion_mm(&prompt, &images, sampling, Some(-5));
-
-    let mut tokens = Vec::new();
-    while let Some(tok) = rx.recv().await {
-        match tok {
-            StreamToken::Token(id) => tokens.push(id),
-            StreamToken::Done { .. } => break,
-            StreamToken::Error(e) => anyhow::bail!("generation error: {}", e),
-        }
-    }
-
-    Ok(tokenizer::decode(&tokens))
-}
-
 // ── Finetune config and persistence ─────────────────────────────
 
 use std::path::PathBuf;
diff --git a/src/subconscious/mod.rs b/src/subconscious/mod.rs
index 433f721..d50f833 100644
--- a/src/subconscious/mod.rs
+++ b/src/subconscious/mod.rs
@@ -3,5 +3,6 @@
 pub mod daemon;
 pub mod defs;
 pub mod digest;
+pub mod generate;
 pub mod learn;
 pub mod prompts;