From 29b3aeca57ca199c3e628b386b09b1795fc7b268 Mon Sep 17 00:00:00 2001 From: Kent Overstreet Date: Thu, 2 Apr 2026 22:35:29 -0400 Subject: [PATCH] chunk scoring calls to avoid OOM on large contexts Split conversation into ~50K token chunks (configurable via scoring_chunk_tokens in config) for prompt_logprobs calls. Each chunk ends at an assistant message boundary. Avoids the ~40GB logprobs tensor allocation that OOM'd on full contexts. Co-Authored-By: Proof of Concept --- src/config.rs | 5 +++ src/thought/training.rs | 75 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/src/config.rs b/src/config.rs index 9a10e24..f1b637b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -55,6 +55,7 @@ pub struct ContextGroup { fn default_true() -> bool { true } fn default_context_window() -> usize { 128_000 } fn default_stream_timeout() -> u64 { 60 } +fn default_scoring_chunk_tokens() -> usize { 50_000 } fn default_identity_dir() -> PathBuf { dirs::home_dir().unwrap_or_default().join(".consciousness/identity") } @@ -95,6 +96,9 @@ pub struct Config { /// Stream chunk timeout in seconds (no data = timeout). #[serde(default = "default_stream_timeout")] pub api_stream_timeout_secs: u64, /// Max tokens per chunk for memory scoring logprobs calls. #[serde(default = "default_scoring_chunk_tokens")] pub scoring_chunk_tokens: usize, pub api_reasoning: String, pub agent_types: Vec<String>, /// Surface agent timeout in seconds.
@@ -143,6 +147,7 @@ impl Default for Config { api_model: None, api_context_window: default_context_window(), api_stream_timeout_secs: default_stream_timeout(), + scoring_chunk_tokens: default_scoring_chunk_tokens(), agent_model: None, api_reasoning: "high".to_string(), agent_types: vec![ diff --git a/src/thought/training.rs b/src/thought/training.rs index 2cfe73a..45deced 100644 --- a/src/thought/training.rs +++ b/src/thought/training.rs @@ -180,24 +180,87 @@ pub async fn score_memories( }) } +/// Rough token estimate: ~4 chars per token. +const CHARS_PER_TOKEN: usize = 4; + /// Get logprobs for all assistant response tokens in a conversation. /// Returns a Vec<Vec<f32>> — one inner vec per assistant response, /// containing logprobs for each token in that response. +/// +/// Chunks the conversation into ~50K token segments (rounded to +/// assistant message boundaries) to avoid OOM from the logprobs +/// tensor allocation. async fn get_response_logprobs( context: &ContextState, entries: &[ConversationEntry], client: &ApiClient, ) -> anyhow::Result<Vec<Vec<f32>>> { - // Assemble messages the same way the runner does - let mut msgs = Vec::new(); - msgs.push(Message::system(&context.system_prompt)); + // Build the fixed prefix (system prompt + personality) + let mut prefix = Vec::new(); + prefix.push(Message::system(&context.system_prompt)); let ctx = context.render_context_message(); if !ctx.is_empty() { - msgs.push(Message::user(ctx)); + prefix.push(Message::user(ctx)); } - msgs.extend(entries.iter().map(|e| e.api_message().clone())); + let prefix_chars: usize = prefix.iter() + .map(|m| m.content_text().len()) + .sum(); - // Call the API with prompt_logprobs + // Split entries into chunks that fit within the token budget, + // each ending at an assistant message boundary.
+ let max_chunk_chars = crate::config::get().scoring_chunk_tokens * CHARS_PER_TOKEN; + let budget = max_chunk_chars.saturating_sub(prefix_chars); + let chunks = chunk_entries(entries, budget); + + let mut all_responses: Vec<Vec<f32>> = Vec::new(); + + for chunk in &chunks { + let mut msgs = prefix.clone(); + msgs.extend(chunk.iter().map(|e| e.api_message().clone())); + + let result = call_prompt_logprobs(&msgs, client).await?; + all_responses.extend(result); + } + + Ok(all_responses) +} + +/// Split entries into chunks of approximately `budget_chars` each, +/// ending at assistant message boundaries. +fn chunk_entries(entries: &[ConversationEntry], budget_chars: usize) -> Vec<Vec<ConversationEntry>> { + let mut chunks = Vec::new(); + let mut current = Vec::new(); + let mut current_chars = 0; + + for entry in entries { + let entry_chars = entry.message().content_text().len(); + current_chars += entry_chars; + current.push(entry.clone()); + + // If over budget and we just added an assistant message, cut here + if current_chars >= budget_chars && entry.message().role == Role::Assistant { + chunks.push(std::mem::take(&mut current)); + current_chars = 0; + } + } + + if !current.is_empty() { + chunks.push(current); + } + + // If everything fit in one chunk, just return it + if chunks.is_empty() { + chunks.push(entries.to_vec()); + } + + chunks +} + +/// Make a single prompt_logprobs API call and extract response logprobs. +async fn call_prompt_logprobs( + msgs: &[Message], + client: &ApiClient, +) -> anyhow::Result<Vec<Vec<f32>>> { + let request = serde_json::json!({ + "model": client.model, + "messages": msgs,