agent: share one tonic Channel + migrate scoring to gRPC Generate

Two changes that bolt together — the shared connection means the new
scoring path actually costs one HTTP/2 handshake across the whole
process instead of one-per-RPC.

ApiClient gains `salience_channel: Arc<OnceCell<Channel>>`. First
call to `ApiClient::salience_client()` opens the channel via
`connect_channel()` and stores the Channel; subsequent calls clone
it (cheap — tonic multiplexes concurrent RPCs over the single
HTTP/2 connection). Every ApiClient clone shares the same OnceCell,
so all agents spawned from Mind's client — plus every ephemeral
scoring session — reuse one connection.

SessionHandle refactored to hold an `ApiClient` clone instead of
a bag of (base_url, api_key) strings. `open` / `append_image` /
`generate` go through `self.client.salience_client()` now. New
`prefill_only(tokens)` method encapsulates the "Generate with
max_tokens=0 to append text" pattern (previously a private free
function in api/mod.rs called `flush_pending`). Drop impl on
SessionHandle stays — still fires CloseSession on the shared
channel in a detached task.

`run_session_generate` switched from `(base_url, api_key, model)`
to `&ApiClient`; the agent-turn flow that uses it keeps the same
shape but `stream_session_mm` clones the ApiClient into the
spawned worker.

learn.rs migrated from the HTTP `/v1/score` endpoint to a gRPC
session-based score:

  * `call_score` opens an ephemeral SessionHandle on the client,
    converts (prompt_tokens, images) → Vec<WireChunk> via the new
    `prompt_to_chunks` helper (splits on VISION_START/VISION_END),
    walks chunks calling `prefill_only` + `append_image`, runs a
    final Generate with `max_tokens=0` + `logprobs_ranges` over
    the scored positions, and sums each Token event's
    `sampled_logprob` per range to produce `ScoreResult`s.

  * SessionHandle drops at end of scope → CloseSession auto-fires,
    keeping the server's session map clean between calls.

  * No more HTTP path, no more `http_client()` helper, no more
    `ScoreResponse` / serde plumbing for /v1/score.

  * `send_to_train` still uses HTTP (it talks to /v1/train which
    isn't on the gRPC protocol); its ad-hoc HTTP client lives
    inline now instead of reaching for the deleted `http_client()`.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-24 12:51:53 -04:00
commit 4feebb7bc4
3 changed files with 268 additions and 213 deletions

View file

@ -1,100 +1,166 @@
// training.rs — Memory importance scoring via /v1/score
// learn.rs — Memory importance scoring over the salience gRPC protocol.
//
// Three scoring modes, all built on the same call_score() primitive:
// Three scoring modes, all built on call_score():
//
// score_memories() — Full N×M matrix (memories × responses) for the
// debug screen. Expensive: N+1 API calls.
// debug screen. Expensive: N+1 sessions/calls.
//
// memory_score() — Single memory importance. Scores the 50 messages
// score_memory() — Single memory importance. Scores the 50 messages
// after it was surfaced, with/without that memory.
// 2 API calls.
// 2 calls.
//
// finetune_score() — Identifies training candidates. Scores recent
// messages with all memories stripped. Responses
// with high divergence depend on memories the model
// hasn't internalized. 2 API calls.
// hasn't internalized. 2 calls.
//
// Each call opens an ephemeral gRPC session (reusing the shared
// tonic Channel on `ApiClient`), pushes the prompt through as
// interleaved tokens + AppendImage calls, runs Generate with
// max_tokens=0 + logprobs_ranges over the scored positions, collects
// each Token event's sampled_logprob, then drops the SessionHandle —
// which triggers a best-effort CloseSession over the shared channel.
use std::sync::Arc;
use crate::agent::api::ApiClient;
use crate::agent::api::salience::{SessionHandle, pb};
use crate::agent::context::{
Ast, AstNode, ContextState, Role, WireImage,
Ast, AstNode, ContextState, Role, WireChunk, WireImage,
is_assistant, is_memory_node, memory_key, render_branch_text, render_prior_context,
};
use crate::agent::tokenizer;
use crate::mind::{MindState, MindTriggered, TaskHandle};
use crate::subconscious::generate::gen_continuation;
const SCORE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
// ── Score API ───────────────────────────────────────────────────
#[derive(serde::Deserialize)]
#[derive(Debug, Clone)]
struct ScoreResult {
total_logprob: f64,
}
#[derive(serde::Deserialize)]
struct ScoreResponse {
scores: Vec<ScoreResult>,
}
fn http_client() -> crate::agent::api::http::HttpClient {
crate::agent::api::http::HttpClient::builder()
.timeout(SCORE_TIMEOUT)
.build()
/// Convert a flat (prompt_tokens, images) pair into the interleaved
/// chunks the session protocol expects. Tokens up to the next
/// `<|vision_start|>` become a Tokens chunk; each
/// `<|vision_start|>..<|vision_end|>` run collapses into one Image
/// chunk paired by position with the next entry in `images`. The
/// server re-expands the IMAGE_PADs on AppendImage.
fn prompt_to_chunks(prompt: &[u32], images: &[WireImage]) -> Vec<WireChunk> {
let mut out: Vec<WireChunk> = Vec::new();
let mut cur = 0;
let mut img_idx = 0;
while cur < prompt.len() {
if prompt[cur] == tokenizer::VISION_START {
let end_rel = prompt[cur..].iter()
.position(|&t| t == tokenizer::VISION_END)
.unwrap_or_else(|| panic!(
"unmatched VISION_START at position {} in prompt", cur));
let end = cur + end_rel + 1;
let img = images.get(img_idx)
.unwrap_or_else(|| panic!(
"image index {} out of range for {} images", img_idx, images.len()));
out.push(WireChunk::Image {
bytes: img.bytes.clone(),
mime: img.mime.clone(),
known_expanded_len: (end - cur) as u32,
});
img_idx += 1;
cur = end;
} else {
let next_vs = prompt[cur..].iter()
.position(|&t| t == tokenizer::VISION_START);
let end = match next_vs {
Some(o) => cur + o,
None => prompt.len(),
};
out.push(WireChunk::Tokens(prompt[cur..end].to_vec()));
cur = end;
}
}
out
}
async fn call_score(
http: &crate::agent::api::http::HttpClient,
client: &ApiClient,
prompt: &[u32],
images: &[WireImage],
ranges: &[(usize, usize)],
priority: Option<i32>,
) -> anyhow::Result<Vec<ScoreResult>> {
use futures::StreamExt;
// Nothing to score — skip the round-trip.
if ranges.is_empty() {
return Ok(Vec::new());
}
let url = format!("{}/score", client.base_url());
let auth = format!("Bearer {}", client.api_key());
let mut body = serde_json::json!({
"model": client.model,
"prompt": prompt,
"score_ranges": ranges,
"logprobs": 1,
});
if !images.is_empty() {
use base64::Engine;
let b64 = base64::engine::general_purpose::STANDARD;
let uris: Vec<String> = images.iter()
.map(|img| format!("data:{};base64,{}", img.mime, b64.encode(&img.bytes)))
.collect();
body["multi_modal_data"] = serde_json::json!({ "image": uris });
}
if let Some(p) = priority {
body["priority"] = serde_json::json!(p);
}
let response = http
.send_json("POST", &url, &[
("authorization", &auth),
], &body)
.await?;
let status = response.status();
let body: serde_json::Value = response.json().await?;
let chunks = prompt_to_chunks(prompt, images);
let mut handle = SessionHandle::open(client).await?;
if !status.is_success() {
let msg = body.get("error").and_then(|e| e.as_str()).unwrap_or("unknown error");
anyhow::bail!("score API HTTP {}: {}", status, msg);
}
if let Some(err) = body.get("error").and_then(|e| e.as_str()) {
anyhow::bail!("score API error: {}", err);
// Walk chunks: AppendImage for each image, prefill-only Generate
// for each text run between images. Accumulate any trailing text
// run into `pending` for the final logprob-generating Generate.
let mut pending: Vec<u32> = Vec::new();
for chunk in chunks {
match chunk {
WireChunk::Tokens(t) => pending.extend(t),
WireChunk::Image { bytes, mime, .. } => {
if !pending.is_empty() {
handle.prefill_only(std::mem::take(&mut pending)).await?;
}
handle.append_image(bytes, mime, false).await?;
}
}
}
let result: ScoreResponse = serde_json::from_value(body)
.map_err(|e| anyhow::anyhow!("failed to parse score response: {}", e))?;
Ok(result.scores)
// Final Generate: max_tokens=0 so the server runs prefill of the
// trailing `pending` tokens and emits Token events for each
// position covered by logprobs_ranges, then Done. logprob_top_k=0
// means "just the sampled (prompt) token's logprob" — no top-k
// alternatives, which is all call_score historically needed.
let logprobs_ranges: Vec<pb::PositionRange> = ranges.iter()
.map(|(s, e)| pb::PositionRange { start: *s as u32, end: *e as u32 })
.collect();
let req = pb::GenerateRequest {
session_id: handle.session_id.clone(),
append_tokens: pending,
offset: handle.committed_len,
truncating: false,
max_tokens: 0,
logprobs_ranges,
logprob_top_k: 0,
readout_ranges: Vec::new(),
temperature: 0.0,
top_p: 0.0,
top_k: 0,
stop_token_ids: Vec::new(),
priority: priority.unwrap_or(0),
};
let mut stream = handle.generate(req).await?;
let mut totals = vec![0.0f64; ranges.len()];
while let Some(event) = stream.next().await {
let event = event
.map_err(|s| anyhow::anyhow!("score Generate stream: {}", s))?;
let Some(inner) = event.event else { continue };
match inner {
pb::generate_event::Event::Token(t) => {
if !t.has_sampled_logprob { continue; }
let pos = t.position as usize;
for (i, (start, end)) in ranges.iter().enumerate() {
if pos >= *start && pos < *end {
totals[i] += t.sampled_logprob as f64;
}
}
}
pb::generate_event::Event::Done(_) => break,
}
}
Ok(totals.into_iter()
.map(|total_logprob| ScoreResult { total_logprob })
.collect())
}
/// Compute per-position logprob divergence: how much worse the model
@ -110,7 +176,6 @@ fn divergence(baseline: &[ScoreResult], without: &[ScoreResult]) -> Vec<f64> {
/// Score two message sets and return total divergence.
async fn score_divergence<F>(
http: &crate::agent::api::http::HttpClient,
client: &ApiClient,
context: &ContextState,
range: std::ops::Range<usize>,
@ -123,9 +188,9 @@ where F: FnMut(&AstNode) -> bool,
context.wire_prompt(range.clone(), |_| false);
let (without_tokens, without_images, without_ranges) =
context.wire_prompt(range, skip);
let baseline = call_score(http, client, &baseline_tokens, &baseline_images,
let baseline = call_score(client, &baseline_tokens, &baseline_images,
&baseline_ranges, priority).await?;
let without = call_score(http, client, &without_tokens, &without_images,
let without = call_score(client, &without_tokens, &without_images,
&without_ranges, priority).await?;
let divs = divergence(&baseline, &without);
Ok((divs, baseline))
@ -162,14 +227,13 @@ pub async fn score_memories(
dbglog!("[scoring-full] starting: {} memories × {} responses",
total, response_indices.len());
let http = http_client();
let activity = crate::agent::start_activity(agent, "scoring: baseline").await;
let (baseline_tokens, baseline_images, baseline_ranges) = {
let ctx = agent.context.lock().await;
ctx.wire_prompt(0..ctx.conversation().len(), |_| false)
};
let baseline = call_score(&http, client, &baseline_tokens, &baseline_images,
let baseline = call_score(client, &baseline_tokens, &baseline_images,
&baseline_ranges, Some(5)).await?;
dbglog!("[scoring-full] baseline done ({} response scores)", baseline.len());
@ -180,7 +244,7 @@ pub async fn score_memories(
let ctx = agent.context.lock().await;
ctx.wire_prompt(0..ctx.conversation().len(), |n| memory_key(n) == Some(key.as_str()))
};
let row = match call_score(&http, client, &tokens, &images, &ranges, Some(5)).await {
let row = match call_score(client, &tokens, &images, &ranges, Some(5)).await {
Ok(without) => {
let divs = divergence(&baseline, &without);
let max_div = divs.iter().cloned().fold(0.0f64, f64::max);
@ -263,8 +327,7 @@ pub async fn score_memory(
return Ok(0.0);
}
let http = http_client();
let (divs, _) = score_divergence(&http, client, context, range,
let (divs, _) = score_divergence(client, context, range,
|n| memory_key(n) == Some(key), Some(5)).await?;
Ok(divs.iter().sum())
@ -322,7 +385,6 @@ where
// Score oldest-first
candidates.sort_by_key(|&(_, _, last)| last);
let http = http_client();
let mut scored = 0;
let entries = context.conversation();
@ -357,7 +419,7 @@ where
}
activity.update(format!("scoring: {}/{} {}", scored + 1, total, key)).await;
match score_divergence(&http, client, context, range,
match score_divergence(client, context, range,
|n| memory_key(n) == Some(key), Some(5)).await {
Ok((divs, _)) => {
let n_responses = divs.len();
@ -505,8 +567,7 @@ pub async fn score_finetune(
return Ok(Vec::new());
}
let http = http_client();
let (divs, _) = score_divergence(&http, client, context, range, is_memory_node, Some(5)).await?;
let (divs, _) = score_divergence(client, context, range, is_memory_node, Some(5)).await?;
let mut results: Vec<(usize, f64)> = response_positions.iter()
.enumerate()
@ -804,8 +865,10 @@ pub async fn send_to_train(
}
});
let http = http_client();
let url = format!("{}/train", client.base_url());
let http = crate::agent::api::http::HttpClient::builder()
.timeout(std::time::Duration::from_secs(300))
.build();
let response = http.send_json("POST", &url, &[], &body).await?;
let status = response.status();