switch memory scoring to /v1/score endpoint
Replace prompt_logprobs-based scoring with the new vLLM /v1/score endpoint. Much simpler: one API call per memory drop, returns per-message total_logprob directly. No chunking needed, no OOM risk — the endpoint only computes logits for scored tokens. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
249726599b
commit
e8c3ed3d96
2 changed files with 99 additions and 203 deletions
|
|
@ -228,6 +228,15 @@ impl Message {
|
||||||
self.content.as_ref().map_or("", |c| c.as_text())
|
self.content.as_ref().map_or("", |c| c.as_text())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn role_str(&self) -> &str {
|
||||||
|
match self.role {
|
||||||
|
Role::System => "system",
|
||||||
|
Role::User => "user",
|
||||||
|
Role::Assistant => "assistant",
|
||||||
|
Role::Tool => "tool",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn now() -> Option<String> {
|
fn now() -> Option<String> {
|
||||||
Some(Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true))
|
Some(Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,15 @@
|
||||||
// training.rs — Memory importance scoring via prompt logprobs
|
// training.rs — Memory importance scoring via /v1/score
|
||||||
//
|
//
|
||||||
// Drops each memory from the context one at a time, runs prompt_logprobs
|
// Drops each memory from the context one at a time, calls the vLLM
|
||||||
// to see how the model's confidence in its responses changes. Produces
|
// /v1/score endpoint to get logprobs for assistant responses.
|
||||||
// a divergence matrix: memories × responses.
|
// Produces a divergence matrix: memories × responses.
|
||||||
//
|
//
|
||||||
// Row sums = memory importance (for graph weight updates)
|
// Row sums = memory importance (for graph weight updates)
|
||||||
// Column sums = response memory-dependence (training candidates)
|
// Column sums = response memory-dependence (training candidates)
|
||||||
|
|
||||||
use crate::agent::api::ApiClient;
|
use crate::agent::api::ApiClient;
|
||||||
use crate::agent::types::*;
|
use crate::agent::types::*;
|
||||||
use crate::agent::ui_channel::UiSender;
|
use crate::agent::ui_channel::{UiMessage, UiSender};
|
||||||
|
|
||||||
/// Result of scoring one conversation's memory usage.
|
/// Result of scoring one conversation's memory usage.
|
||||||
pub struct MemoryScore {
|
pub struct MemoryScore {
|
||||||
|
|
@ -21,13 +21,12 @@ pub struct MemoryScore {
|
||||||
pub matrix: Vec<Vec<f64>>,
|
pub matrix: Vec<Vec<f64>>,
|
||||||
/// Keys of memories that were scored
|
/// Keys of memories that were scored
|
||||||
pub memory_keys: Vec<String>,
|
pub memory_keys: Vec<String>,
|
||||||
/// Conversation entry indices of the assistant responses (maps response_idx → entry_idx)
|
/// Conversation entry indices of the assistant responses
|
||||||
pub response_entry_indices: Vec<usize>,
|
pub response_entry_indices: Vec<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MemoryScore {
|
impl MemoryScore {
|
||||||
/// Get the most important memories for a given conversation entry index.
|
/// Get the most important memories for a given conversation entry index.
|
||||||
/// Returns (memory_key, divergence_score) sorted by importance.
|
|
||||||
pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> {
|
pub fn important_memories_for_entry(&self, entry_idx: usize) -> Vec<(&str, f64)> {
|
||||||
let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx)
|
let Some(resp_idx) = self.response_entry_indices.iter().position(|&i| i == entry_idx)
|
||||||
else { return Vec::new() };
|
else { return Vec::new() };
|
||||||
|
|
@ -45,17 +44,11 @@ impl MemoryScore {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Score how important each memory is to the conversation.
|
/// Score how important each memory is to the conversation.
|
||||||
///
|
|
||||||
/// For each Memory entry in the context, builds a version without it
|
|
||||||
/// and checks how the model's logprobs change for assistant responses.
|
|
||||||
pub async fn score_memories(
|
pub async fn score_memories(
|
||||||
context: &ContextState,
|
context: &ContextState,
|
||||||
client: &ApiClient,
|
client: &ApiClient,
|
||||||
ui_tx: &UiSender,
|
ui_tx: &UiSender,
|
||||||
) -> anyhow::Result<MemoryScore> {
|
) -> anyhow::Result<MemoryScore> {
|
||||||
use crate::agent::ui_channel::UiMessage;
|
|
||||||
|
|
||||||
// Identify memory entries and assistant response positions
|
|
||||||
let memories: Vec<(usize, String)> = context.entries.iter().enumerate()
|
let memories: Vec<(usize, String)> = context.entries.iter().enumerate()
|
||||||
.filter_map(|(i, e)| match e {
|
.filter_map(|(i, e)| match e {
|
||||||
ConversationEntry::Memory { key, .. } => Some((i, key.clone())),
|
ConversationEntry::Memory { key, .. } => Some((i, key.clone())),
|
||||||
|
|
@ -79,22 +72,32 @@ pub async fn score_memories(
|
||||||
}
|
}
|
||||||
|
|
||||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
"[training] scoring {} memories × {} responses",
|
"[training] scoring {} memories × {} responses via /v1/score",
|
||||||
memories.len(), response_indices.len(),
|
memories.len(), response_indices.len(),
|
||||||
)));
|
)));
|
||||||
|
|
||||||
// Shared HTTP client for connection reuse across all scoring calls
|
|
||||||
let http = reqwest::Client::builder()
|
let http = reqwest::Client::builder()
|
||||||
.pool_max_idle_per_host(2)
|
.pool_max_idle_per_host(2)
|
||||||
.build()
|
.build()
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
// Baseline: logprobs with all memories present
|
// Build the messages array from context
|
||||||
let baseline = get_response_logprobs(context, &context.entries, client, &http, ui_tx).await?;
|
let all_messages = build_messages(context);
|
||||||
|
|
||||||
|
let roles: Vec<&str> = all_messages.iter()
|
||||||
|
.map(|m| m.get("role").and_then(|r| r.as_str()).unwrap_or("?"))
|
||||||
|
.collect();
|
||||||
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
|
"[training] sending {} messages, roles: {:?}",
|
||||||
|
all_messages.len(), roles,
|
||||||
|
)));
|
||||||
|
|
||||||
|
// Baseline: score with all memories present
|
||||||
|
let baseline = call_score(&http, client, &all_messages, ui_tx).await?;
|
||||||
|
|
||||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
"[training] baseline: {} response tokens scored",
|
"[training] baseline: {} messages scored",
|
||||||
baseline.iter().map(|r| r.len()).sum::<usize>(),
|
baseline.len(),
|
||||||
)));
|
)));
|
||||||
|
|
||||||
// For each memory, drop it and measure divergence
|
// For each memory, drop it and measure divergence
|
||||||
|
|
@ -110,28 +113,27 @@ pub async fn score_memories(
|
||||||
mem_idx + 1, memories.len(), key,
|
mem_idx + 1, memories.len(), key,
|
||||||
)));
|
)));
|
||||||
|
|
||||||
// Build entries without this memory
|
// Build messages without this memory
|
||||||
let filtered: Vec<ConversationEntry> = context.entries.iter().enumerate()
|
let filtered_messages = build_messages_without(context, *entry_idx);
|
||||||
.filter(|(i, _)| *i != *entry_idx)
|
let without = call_score(&http, client, &filtered_messages, ui_tx).await?;
|
||||||
.map(|(_, e)| e.clone())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let without = get_response_logprobs(context, &filtered, client, &http, ui_tx).await?;
|
// Match scores by message index and compute divergence
|
||||||
|
|
||||||
// Compute per-response divergence
|
|
||||||
let mut row = Vec::new();
|
let mut row = Vec::new();
|
||||||
for (_resp_idx, (base_lps, without_lps)) in baseline.iter().zip(without.iter()).enumerate() {
|
for base_score in &baseline {
|
||||||
// Sum of logprob drops across tokens in this response
|
let base_lp = base_score.total_logprob;
|
||||||
|
let without_lp = without.iter()
|
||||||
|
.find(|s| s.message_index == base_score.message_index)
|
||||||
|
.map(|s| s.total_logprob)
|
||||||
|
.unwrap_or(base_lp);
|
||||||
// Positive = memory helped (logprob was higher with it)
|
// Positive = memory helped (logprob was higher with it)
|
||||||
let divergence: f64 = base_lps.iter().zip(without_lps.iter())
|
let divergence = (base_lp - without_lp).max(0.0);
|
||||||
.map(|(b, w)| b - w) // positive when baseline was more confident
|
|
||||||
.filter(|d| *d > 0.0) // only count where memory helped
|
|
||||||
.sum();
|
|
||||||
row.push(divergence);
|
row.push(divergence);
|
||||||
}
|
}
|
||||||
matrix.push(row);
|
matrix.push(row);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let _ = ui_tx.send(UiMessage::Activity(String::new()));
|
||||||
|
|
||||||
// Compute scores
|
// Compute scores
|
||||||
let memory_weights: Vec<(String, f64)> = memory_keys.iter()
|
let memory_weights: Vec<(String, f64)> = memory_keys.iter()
|
||||||
.zip(matrix.iter())
|
.zip(matrix.iter())
|
||||||
|
|
@ -148,35 +150,13 @@ pub async fn score_memories(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let _ = ui_tx.send(UiMessage::Activity(String::new()));
|
// Log summary
|
||||||
|
|
||||||
// Log summary per memory
|
|
||||||
for (key, score) in &memory_weights {
|
for (key, score) in &memory_weights {
|
||||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
"[training] {} → importance {:.1}", key, score,
|
"[training] {} → importance {:.1}", key, score,
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log per-response breakdown for the most important memories
|
|
||||||
let mut sorted_mems: Vec<(usize, &str, f64)> = memory_keys.iter().enumerate()
|
|
||||||
.map(|(i, k)| (i, k.as_str(), memory_weights[i].1))
|
|
||||||
.collect();
|
|
||||||
sorted_mems.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
|
|
||||||
|
|
||||||
for (mem_i, key, total) in sorted_mems.iter().take(5) {
|
|
||||||
if *total <= 0.0 { continue; }
|
|
||||||
let row = &matrix[*mem_i];
|
|
||||||
let top_responses: Vec<String> = row.iter().enumerate()
|
|
||||||
.filter(|(_, v)| **v > 0.1)
|
|
||||||
.map(|(j, v)| format!("resp[{}]={:.1}", j, v))
|
|
||||||
.collect();
|
|
||||||
if !top_responses.is_empty() {
|
|
||||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
|
||||||
"[training] {} ({:.1}): {}", key, total, top_responses.join(", "),
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(MemoryScore {
|
Ok(MemoryScore {
|
||||||
memory_weights,
|
memory_weights,
|
||||||
response_scores,
|
response_scores,
|
||||||
|
|
@ -186,116 +166,70 @@ pub async fn score_memories(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Rough token estimate: ~4 chars per token.
|
/// Score response from the /v1/score endpoint.
|
||||||
const CHARS_PER_TOKEN: usize = 4;
|
#[derive(serde::Deserialize)]
|
||||||
|
struct ScoreMessageResult {
|
||||||
|
message_index: usize,
|
||||||
|
total_logprob: f64,
|
||||||
|
}
|
||||||
|
|
||||||
/// Get logprobs for all assistant response tokens in a conversation.
|
#[derive(serde::Deserialize)]
|
||||||
/// Returns a Vec<Vec<f64>> — one inner vec per assistant response,
|
struct ScoreApiResponse {
|
||||||
/// containing logprobs for each token in that response.
|
scores: Vec<ScoreMessageResult>,
|
||||||
///
|
}
|
||||||
/// Chunks the conversation into ~50K token segments (rounded to
|
|
||||||
/// assistant message boundaries) to avoid OOM from the logprobs
|
/// Build the messages array for the /v1/score endpoint from ContextState.
|
||||||
/// tensor allocation.
|
fn build_messages(context: &ContextState) -> Vec<serde_json::Value> {
|
||||||
async fn get_response_logprobs(
|
let mut msgs = Vec::new();
|
||||||
context: &ContextState,
|
msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt}));
|
||||||
entries: &[ConversationEntry],
|
|
||||||
client: &ApiClient,
|
|
||||||
http: &reqwest::Client,
|
|
||||||
ui_tx: &UiSender,
|
|
||||||
) -> anyhow::Result<Vec<Vec<f64>>> {
|
|
||||||
// Build the fixed prefix (system prompt + personality)
|
|
||||||
let mut prefix = Vec::new();
|
|
||||||
prefix.push(Message::system(&context.system_prompt));
|
|
||||||
let ctx = context.render_context_message();
|
let ctx = context.render_context_message();
|
||||||
if !ctx.is_empty() {
|
if !ctx.is_empty() {
|
||||||
prefix.push(Message::user(ctx));
|
msgs.push(serde_json::json!({"role": "user", "content": ctx}));
|
||||||
}
|
}
|
||||||
let prefix_chars: usize = prefix.iter()
|
for entry in &context.entries {
|
||||||
.map(|m| m.content_text().len())
|
let m = entry.api_message();
|
||||||
.sum();
|
msgs.push(serde_json::json!({
|
||||||
|
"role": m.role_str(),
|
||||||
// Split entries into chunks that fit within the token budget,
|
"content": m.content_text(),
|
||||||
// each ending at an assistant message boundary.
|
}));
|
||||||
let max_chunk_chars = crate::config::get().scoring_chunk_tokens * CHARS_PER_TOKEN;
|
|
||||||
let budget = max_chunk_chars.saturating_sub(prefix_chars);
|
|
||||||
let chunks = chunk_entries(entries, budget);
|
|
||||||
|
|
||||||
let mut all_responses: Vec<Vec<f64>> = Vec::new();
|
|
||||||
|
|
||||||
use crate::agent::ui_channel::UiMessage;
|
|
||||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
|
||||||
"[training] {} chunks, prefix={}K chars, budget={}K chars",
|
|
||||||
chunks.len(), prefix_chars / 1024, budget / 1024,
|
|
||||||
)));
|
|
||||||
|
|
||||||
for (chunk_idx, chunk) in chunks.iter().enumerate() {
|
|
||||||
let chunk_chars: usize = chunk.iter()
|
|
||||||
.map(|e| e.message().content_text().len())
|
|
||||||
.sum();
|
|
||||||
let _ = ui_tx.send(UiMessage::Debug(format!(
|
|
||||||
"[training] chunk {}/{}: {} entries, {}K chars",
|
|
||||||
chunk_idx + 1, chunks.len(), chunk.len(), chunk_chars / 1024,
|
|
||||||
)));
|
|
||||||
}
|
}
|
||||||
|
msgs
|
||||||
for chunk in &chunks {
|
|
||||||
let mut msgs = prefix.clone();
|
|
||||||
msgs.extend(chunk.iter().map(|e| e.api_message().clone()));
|
|
||||||
|
|
||||||
let result = call_prompt_logprobs(&msgs, client, http).await?;
|
|
||||||
all_responses.extend(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(all_responses)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Split entries into chunks of approximately `budget_chars` each,
|
/// Build messages with one entry removed.
|
||||||
/// ending at assistant message boundaries.
|
fn build_messages_without(context: &ContextState, skip_idx: usize) -> Vec<serde_json::Value> {
|
||||||
fn chunk_entries(entries: &[ConversationEntry], budget_chars: usize) -> Vec<Vec<ConversationEntry>> {
|
let mut msgs = Vec::new();
|
||||||
let mut chunks = Vec::new();
|
msgs.push(serde_json::json!({"role": "system", "content": &context.system_prompt}));
|
||||||
let mut current = Vec::new();
|
let ctx = context.render_context_message();
|
||||||
let mut current_chars = 0;
|
if !ctx.is_empty() {
|
||||||
|
msgs.push(serde_json::json!({"role": "user", "content": ctx}));
|
||||||
for entry in entries {
|
|
||||||
let entry_chars = entry.message().content_text().len();
|
|
||||||
current_chars += entry_chars;
|
|
||||||
current.push(entry.clone());
|
|
||||||
|
|
||||||
// If over budget and we just added an assistant message, cut here
|
|
||||||
if current_chars >= budget_chars && entry.message().role == Role::Assistant {
|
|
||||||
chunks.push(std::mem::take(&mut current));
|
|
||||||
current_chars = 0;
|
|
||||||
}
|
}
|
||||||
|
for (i, entry) in context.entries.iter().enumerate() {
|
||||||
|
if i == skip_idx { continue; }
|
||||||
|
let m = entry.api_message();
|
||||||
|
msgs.push(serde_json::json!({
|
||||||
|
"role": m.role_str(),
|
||||||
|
"content": m.content_text(),
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
|
msgs
|
||||||
if !current.is_empty() {
|
|
||||||
chunks.push(current);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If everything fit in one chunk, just return it
|
|
||||||
if chunks.is_empty() {
|
|
||||||
chunks.push(entries.to_vec());
|
|
||||||
}
|
|
||||||
|
|
||||||
chunks
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Make a single prompt_logprobs API call and extract response logprobs.
|
/// Call the /v1/score endpoint and return per-message logprobs.
|
||||||
async fn call_prompt_logprobs(
|
async fn call_score(
|
||||||
msgs: &[Message],
|
|
||||||
client: &ApiClient,
|
|
||||||
http: &reqwest::Client,
|
http: &reqwest::Client,
|
||||||
) -> anyhow::Result<Vec<Vec<f64>>> {
|
client: &ApiClient,
|
||||||
|
messages: &[serde_json::Value],
|
||||||
|
ui_tx: &UiSender,
|
||||||
|
) -> anyhow::Result<Vec<ScoreMessageResult>> {
|
||||||
let request = serde_json::json!({
|
let request = serde_json::json!({
|
||||||
"model": client.model,
|
"model": client.model,
|
||||||
"messages": msgs,
|
"messages": messages,
|
||||||
"max_tokens": 1,
|
"logprobs": 1,
|
||||||
"prompt_logprobs": 1,
|
|
||||||
"stream": false,
|
|
||||||
});
|
});
|
||||||
|
|
||||||
let response = http
|
let response = http
|
||||||
.post(format!("{}/chat/completions", client.base_url()))
|
.post(format!("{}/score", client.base_url()))
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.header("Authorization", format!("Bearer {}", client.api_key()))
|
.header("Authorization", format!("Bearer {}", client.api_key()))
|
||||||
.json(&request)
|
.json(&request)
|
||||||
|
|
@ -307,61 +241,14 @@ async fn call_prompt_logprobs(
|
||||||
|
|
||||||
if !status.is_success() {
|
if !status.is_success() {
|
||||||
let msg = body.get("error")
|
let msg = body.get("error")
|
||||||
.and_then(|e| e.get("message"))
|
.and_then(|e| e.as_str())
|
||||||
.and_then(|m| m.as_str())
|
|
||||||
.unwrap_or("unknown error");
|
.unwrap_or("unknown error");
|
||||||
anyhow::bail!("HTTP {} from logprobs API: {}", status, msg);
|
let _ = ui_tx.send(UiMessage::Debug(format!(
|
||||||
|
"[training] score API error: {}", msg,
|
||||||
|
)));
|
||||||
|
anyhow::bail!("score API error: {}", msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
let prompt_logprobs = body.get("prompt_logprobs")
|
let result: ScoreApiResponse = serde_json::from_value(body)?;
|
||||||
.and_then(|v| v.as_array())
|
Ok(result.scores)
|
||||||
.ok_or_else(|| anyhow::anyhow!("no prompt_logprobs in response"))?;
|
|
||||||
|
|
||||||
// Find assistant response boundaries using special tokens
|
|
||||||
// Pattern: <|im_start|> assistant \n [<think>...</think>] response <|im_end|>
|
|
||||||
let mut responses: Vec<Vec<f64>> = Vec::new();
|
|
||||||
let mut in_assistant = false;
|
|
||||||
let mut in_think = false;
|
|
||||||
let mut current_response: Vec<f64> = Vec::new();
|
|
||||||
|
|
||||||
for entry in prompt_logprobs {
|
|
||||||
let Some(obj) = entry.as_object() else { continue };
|
|
||||||
let first = obj.values().next();
|
|
||||||
let Some(info) = first.and_then(|v| v.as_object()) else { continue };
|
|
||||||
let token = info.get("decoded_token").and_then(|v| v.as_str()).unwrap_or("");
|
|
||||||
let logprob = info.get("logprob").and_then(|v| v.as_f64()).unwrap_or(0.0);
|
|
||||||
|
|
||||||
match token {
|
|
||||||
"<|im_start|>" => {
|
|
||||||
in_assistant = false;
|
|
||||||
in_think = false;
|
|
||||||
}
|
|
||||||
"assistant" if !in_assistant => {
|
|
||||||
in_assistant = true;
|
|
||||||
in_think = false;
|
|
||||||
current_response.clear();
|
|
||||||
}
|
|
||||||
"<think>" if in_assistant => {
|
|
||||||
in_think = true;
|
|
||||||
}
|
|
||||||
"</think>" if in_assistant => {
|
|
||||||
in_think = false;
|
|
||||||
}
|
|
||||||
"<|im_end|>" if in_assistant => {
|
|
||||||
if !current_response.is_empty() {
|
|
||||||
responses.push(std::mem::take(&mut current_response));
|
|
||||||
}
|
|
||||||
in_assistant = false;
|
|
||||||
}
|
|
||||||
"\n" if in_assistant && current_response.is_empty() => {
|
|
||||||
// Skip the newline right after "assistant"
|
|
||||||
}
|
|
||||||
_ if in_assistant && !in_think => {
|
|
||||||
current_response.push(logprob);
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(responses)
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue