Chunk scoring calls to avoid OOM on large contexts
Split conversation into ~50K token chunks (configurable via scoring_chunk_tokens in config) for prompt_logprobs calls. Each chunk ends at an assistant message boundary. Avoids the ~40GB logprobs tensor allocation that OOM'd on full contexts. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
19205b9bae
commit
29b3aeca57
2 changed files with 74 additions and 6 deletions
|
|
@ -55,6 +55,7 @@ pub struct ContextGroup {
|
||||||
/// Serde default: opt-in boolean fields start enabled.
fn default_true() -> bool {
    true
}

/// Serde default for `api_context_window`: 128K tokens.
fn default_context_window() -> usize {
    128_000
}

/// Serde default for `api_stream_timeout_secs`: 60 seconds with no
/// data counts as a timeout.
fn default_stream_timeout() -> u64 {
    60
}

/// Serde default for `scoring_chunk_tokens`: ~50K tokens per
/// memory-scoring logprobs call (keeps the logprobs tensor small
/// enough to avoid OOM on large contexts).
fn default_scoring_chunk_tokens() -> usize {
    50_000
}
|
||||||
fn default_identity_dir() -> PathBuf {
|
fn default_identity_dir() -> PathBuf {
|
||||||
dirs::home_dir().unwrap_or_default().join(".consciousness/identity")
|
dirs::home_dir().unwrap_or_default().join(".consciousness/identity")
|
||||||
}
|
}
|
||||||
|
|
@ -95,6 +96,9 @@ pub struct Config {
|
||||||
/// Stream chunk timeout in seconds (no data = timeout).
|
/// Stream chunk timeout in seconds (no data = timeout).
|
||||||
#[serde(default = "default_stream_timeout")]
|
#[serde(default = "default_stream_timeout")]
|
||||||
pub api_stream_timeout_secs: u64,
|
pub api_stream_timeout_secs: u64,
|
||||||
|
/// Max tokens per chunk for memory scoring logprobs calls.
|
||||||
|
#[serde(default = "default_scoring_chunk_tokens")]
|
||||||
|
pub scoring_chunk_tokens: usize,
|
||||||
pub api_reasoning: String,
|
pub api_reasoning: String,
|
||||||
pub agent_types: Vec<String>,
|
pub agent_types: Vec<String>,
|
||||||
/// Surface agent timeout in seconds.
|
/// Surface agent timeout in seconds.
|
||||||
|
|
@ -143,6 +147,7 @@ impl Default for Config {
|
||||||
api_model: None,
|
api_model: None,
|
||||||
api_context_window: default_context_window(),
|
api_context_window: default_context_window(),
|
||||||
api_stream_timeout_secs: default_stream_timeout(),
|
api_stream_timeout_secs: default_stream_timeout(),
|
||||||
|
scoring_chunk_tokens: default_scoring_chunk_tokens(),
|
||||||
agent_model: None,
|
agent_model: None,
|
||||||
api_reasoning: "high".to_string(),
|
api_reasoning: "high".to_string(),
|
||||||
agent_types: vec![
|
agent_types: vec![
|
||||||
|
|
|
||||||
|
|
@ -180,24 +180,87 @@ pub async fn score_memories(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Rough token estimate: ~4 chars per token.
|
||||||
|
const CHARS_PER_TOKEN: usize = 4;
|
||||||
|
|
||||||
/// Get logprobs for all assistant response tokens in a conversation.
|
/// Get logprobs for all assistant response tokens in a conversation.
|
||||||
/// Returns a Vec<Vec<f64>> — one inner vec per assistant response,
|
/// Returns a Vec<Vec<f64>> — one inner vec per assistant response,
|
||||||
/// containing logprobs for each token in that response.
|
/// containing logprobs for each token in that response.
|
||||||
|
///
|
||||||
|
/// Chunks the conversation into ~50K token segments (rounded to
|
||||||
|
/// assistant message boundaries) to avoid OOM from the logprobs
|
||||||
|
/// tensor allocation.
|
||||||
async fn get_response_logprobs(
|
async fn get_response_logprobs(
|
||||||
context: &ContextState,
|
context: &ContextState,
|
||||||
entries: &[ConversationEntry],
|
entries: &[ConversationEntry],
|
||||||
client: &ApiClient,
|
client: &ApiClient,
|
||||||
) -> anyhow::Result<Vec<Vec<f64>>> {
|
) -> anyhow::Result<Vec<Vec<f64>>> {
|
||||||
// Assemble messages the same way the runner does
|
// Build the fixed prefix (system prompt + personality)
|
||||||
let mut msgs = Vec::new();
|
let mut prefix = Vec::new();
|
||||||
msgs.push(Message::system(&context.system_prompt));
|
prefix.push(Message::system(&context.system_prompt));
|
||||||
let ctx = context.render_context_message();
|
let ctx = context.render_context_message();
|
||||||
if !ctx.is_empty() {
|
if !ctx.is_empty() {
|
||||||
msgs.push(Message::user(ctx));
|
prefix.push(Message::user(ctx));
|
||||||
}
|
}
|
||||||
msgs.extend(entries.iter().map(|e| e.api_message().clone()));
|
let prefix_chars: usize = prefix.iter()
|
||||||
|
.map(|m| m.content_text().len())
|
||||||
|
.sum();
|
||||||
|
|
||||||
// Call the API with prompt_logprobs
|
// Split entries into chunks that fit within the token budget,
|
||||||
|
// each ending at an assistant message boundary.
|
||||||
|
let max_chunk_chars = crate::config::get().scoring_chunk_tokens * CHARS_PER_TOKEN;
|
||||||
|
let budget = max_chunk_chars.saturating_sub(prefix_chars);
|
||||||
|
let chunks = chunk_entries(entries, budget);
|
||||||
|
|
||||||
|
let mut all_responses: Vec<Vec<f64>> = Vec::new();
|
||||||
|
|
||||||
|
for chunk in &chunks {
|
||||||
|
let mut msgs = prefix.clone();
|
||||||
|
msgs.extend(chunk.iter().map(|e| e.api_message().clone()));
|
||||||
|
|
||||||
|
let result = call_prompt_logprobs(&msgs, client).await?;
|
||||||
|
all_responses.extend(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(all_responses)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Split entries into chunks of approximately `budget_chars` each,
|
||||||
|
/// ending at assistant message boundaries.
|
||||||
|
fn chunk_entries(entries: &[ConversationEntry], budget_chars: usize) -> Vec<Vec<ConversationEntry>> {
|
||||||
|
let mut chunks = Vec::new();
|
||||||
|
let mut current = Vec::new();
|
||||||
|
let mut current_chars = 0;
|
||||||
|
|
||||||
|
for entry in entries {
|
||||||
|
let entry_chars = entry.message().content_text().len();
|
||||||
|
current_chars += entry_chars;
|
||||||
|
current.push(entry.clone());
|
||||||
|
|
||||||
|
// If over budget and we just added an assistant message, cut here
|
||||||
|
if current_chars >= budget_chars && entry.message().role == Role::Assistant {
|
||||||
|
chunks.push(std::mem::take(&mut current));
|
||||||
|
current_chars = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !current.is_empty() {
|
||||||
|
chunks.push(current);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If everything fit in one chunk, just return it
|
||||||
|
if chunks.is_empty() {
|
||||||
|
chunks.push(entries.to_vec());
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Make a single prompt_logprobs API call and extract response logprobs.
|
||||||
|
async fn call_prompt_logprobs(
|
||||||
|
msgs: &[Message],
|
||||||
|
client: &ApiClient,
|
||||||
|
) -> anyhow::Result<Vec<Vec<f64>>> {
|
||||||
let request = serde_json::json!({
|
let request = serde_json::json!({
|
||||||
"model": client.model,
|
"model": client.model,
|
||||||
"messages": msgs,
|
"messages": msgs,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue