forked from kent/consciousness
Two changes that bolt together — the shared connection means the new
scoring path actually costs one HTTP/2 handshake across the whole
process instead of one-per-RPC.
ApiClient gains `salience_channel: Arc<OnceCell<Channel>>`. First
call to `ApiClient::salience_client()` opens the channel via
`connect_channel()` and stores the Channel; subsequent calls clone
it (cheap — tonic multiplexes concurrent RPCs over the single
HTTP/2 connection). Every ApiClient clone shares the same OnceCell,
so all agents spawned from Mind's client — plus every ephemeral
scoring session — reuse one connection.
SessionHandle refactored to hold an `ApiClient` clone instead of
a bag of (base_url, api_key) strings. `open` / `append_image` /
`generate` go through `self.client.salience_client()` now. New
`prefill_only(tokens)` method encapsulates the "Generate with
max_tokens=0 to append text" pattern (previously a private free
function in api/mod.rs called `flush_pending`). Drop impl on
SessionHandle stays — still fires CloseSession on the shared
channel in a detached task.
`run_session_generate` switched from `(base_url, api_key, model)`
to `&ApiClient`; the agent-turn flow that uses it keeps the same
shape but `stream_session_mm` clones the ApiClient into the
spawned worker.
learn.rs migrated from the HTTP `/v1/score` endpoint to a gRPC
session-based score:
* `call_score` opens an ephemeral SessionHandle on the client,
converts (prompt_tokens, images) → Vec<WireChunk> via the new
`prompt_to_chunks` helper (splits on VISION_START/VISION_END),
walks chunks calling `prefill_only` + `append_image`, runs a
final Generate with `max_tokens=0` + `logprobs_ranges` over
the scored positions, and sums each Token event's
`sampled_logprob` per range to produce `ScoreResult`s.
* SessionHandle drops at end of scope → CloseSession auto-fires,
keeping the server's session map clean between calls.
* No more HTTP path, no more `http_client()` helper, no more
`ScoreResponse` / serde plumbing for /v1/score.
* `send_to_train` still uses HTTP (it talks to /v1/train which
isn't on the gRPC protocol); its ad-hoc HTTP client lives
inline now instead of reaching for the deleted `http_client()`.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
417 lines
16 KiB
Rust
417 lines
16 KiB
Rust
// api/ — LLM API client (OpenAI-compatible)
|
|
//
|
|
// Works with any provider that implements the OpenAI chat completions
|
|
// API: OpenRouter, vLLM, llama.cpp, Fireworks, Together, etc.
|
|
//
|
|
// Diagnostics: anomalies always logged to debug panel.
|
|
// Set POC_DEBUG=1 for verbose per-turn logging.
|
|
|
|
pub mod http;
|
|
pub mod salience;
|
|
|
|
use std::time::Duration;
|
|
use anyhow::Result;
|
|
use tokio::sync::mpsc;
|
|
use serde::Deserialize;
|
|
|
|
use http::HttpClient;
|
|
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
pub struct Usage {
|
|
pub prompt_tokens: u32,
|
|
pub completion_tokens: u32,
|
|
pub total_tokens: u32,
|
|
}
|
|
|
|
/// Concept-readout manifest returned by the vLLM server's
|
|
/// `/v1/readout/manifest` endpoint. Maps the nameless tensor indices
|
|
/// in streaming `readout` fields back to concept names and layer
|
|
/// indices.
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
pub struct ReadoutManifest {
|
|
pub concepts: Vec<String>,
|
|
pub layers: Vec<u32>,
|
|
}
|
|
|
|
/// Concept projections for a single sampled token, one row per layer.
///
/// Indexed as `[layer][concept]`, i.e. shape `[n_layers][n_concepts]`.
/// The values are unnamed on their own; names come from pairing the
/// indices with the manifest fetched at startup.
pub type TokenReadout = Vec<Vec<f32>>;
|
|
|
|
/// Client-side sampling state.
///
/// Mirrors the wire-level fields in `GenerateRequest` (the proto
/// flattened its `SamplingParams` submessage so the server handler
/// reads them directly), but stays a grouped struct on the client
/// because UI / config / tests pass these values around together.
///
/// Derives `Debug` and `PartialEq` so values can be logged and
/// compared in tests; `Eq` is not derivable because of the `f32`
/// fields.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SamplingParams {
    /// Sampling temperature, forwarded as-is to the request.
    pub temperature: f32,
    /// Nucleus (top-p) value, forwarded as-is to the request.
    pub top_p: f32,
    /// Top-k value, forwarded as-is to the request.
    pub top_k: u32,
    /// Decode budget. 0 = prefill only; >0 = decode up to this many
    /// tokens, stopping early on EOS / stop_token_ids.
    pub max_tokens: u32,
}
|
|
|
|
/// A JoinHandle that aborts its task when dropped.
|
|
pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>);
|
|
|
|
impl Drop for AbortOnDrop {
|
|
fn drop(&mut self) {
|
|
self.0.abort();
|
|
}
|
|
}
|
|
|
|
|
|
// ─────────────────────────────────────────────────────────────
|
|
// Stream events — yielded by backends, consumed by the runner
|
|
// ─────────────────────────────────────────────────────────────
|
|
|
|
/// One token from the streaming completions API.
|
|
pub enum StreamToken {
|
|
/// A sampled token, optionally with its per-layer concept readout.
|
|
/// `readout` is `None` when the server has readout disabled or
|
|
/// returned no readout for this chunk.
|
|
Token { id: u32, readout: Option<TokenReadout> },
|
|
/// An image was committed server-side via AppendImage during this
|
|
/// stream. `placeholder_count` is the N IMAGE_PADs the server
|
|
/// wrote. Emitted in AST order — caller applies these counts to
|
|
/// the first-N image leaves that currently have token_count=0
|
|
/// via `ContextState::commit_image_token_counts`.
|
|
ImageAppended { placeholder_count: u32 },
|
|
Done { usage: Option<Usage> },
|
|
Error(String),
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct ApiClient {
|
|
client: HttpClient,
|
|
api_key: String,
|
|
pub model: String,
|
|
base_url: String,
|
|
/// Cached readout manifest — fetched once per process and shared
|
|
/// across ApiClient clones (every Agent/fork gets the same cell).
|
|
/// `None` after fetch means the server has readout disabled (404).
|
|
manifest: std::sync::Arc<tokio::sync::OnceCell<Option<ReadoutManifest>>>,
|
|
/// Shared tonic Channel to the salience gRPC endpoint. Opened on
|
|
/// first use and reused across every SessionHandle / RPC call
|
|
/// derived from this ApiClient. tonic multiplexes concurrent
|
|
/// requests over the HTTP/2 connection automatically.
|
|
salience_channel: std::sync::Arc<
|
|
tokio::sync::OnceCell<tonic::transport::Channel>
|
|
>,
|
|
}
|
|
|
|
impl ApiClient {
|
|
pub fn new(base_url: &str, api_key: &str, model: &str) -> Self {
|
|
let client = HttpClient::builder()
|
|
.connect_timeout(Duration::from_secs(30))
|
|
.timeout(Duration::from_secs(600))
|
|
.build();
|
|
|
|
Self {
|
|
client,
|
|
api_key: api_key.to_string(),
|
|
model: model.to_string(),
|
|
base_url: base_url.trim_end_matches('/').to_string(),
|
|
manifest: std::sync::Arc::new(tokio::sync::OnceCell::new()),
|
|
salience_channel: std::sync::Arc::new(tokio::sync::OnceCell::new()),
|
|
}
|
|
}
|
|
|
|
/// Return a `SalienceClient` on the shared gRPC channel — opens
|
|
/// the channel on first call and reuses it thereafter across
|
|
/// every ApiClient clone. All scoring / inference / session
|
|
/// RPCs flow through this single multiplexed HTTP/2 connection.
|
|
pub async fn salience_client(&self) -> Result<
|
|
salience::pb::salience_client::SalienceClient<tonic::transport::Channel>
|
|
> {
|
|
let ch = self.salience_channel.get_or_try_init(|| async {
|
|
let grpc_url = salience::derive_grpc_url(&self.base_url);
|
|
log::debug!(target: "grpc",
|
|
"opening shared salience channel: http_base={} -> grpc_url={}",
|
|
self.base_url, grpc_url);
|
|
salience::connect_channel(&grpc_url).await
|
|
}).await?;
|
|
Ok(salience::pb::salience_client::SalienceClient::new(ch.clone()))
|
|
}
|
|
|
|
/// Stream generation via a gRPC session. Walks the prompt chunks
|
|
/// comparing against the session's `committed_len`, sends the
|
|
/// delta as interleaved `AppendImage` + intermediate
|
|
/// `Generate(max_tokens=0)` (for text runs separating images) +
|
|
/// a final `Generate(max_tokens=sampling.max_tokens, ...)` whose
|
|
/// Token events stream back through the channel.
|
|
///
|
|
/// On any gRPC error the session is dropped; the next call
|
|
/// reopens fresh. Happy-path ordering: Token* Done. Error paths
|
|
/// emit `StreamToken::Error` and close.
|
|
pub(crate) fn stream_session_mm(
|
|
&self,
|
|
session_lock: std::sync::Arc<crate::Mutex<Option<salience::SessionHandle>>>,
|
|
chunks: Vec<super::context::WireChunk>,
|
|
sampling: SamplingParams,
|
|
priority: Option<i32>,
|
|
readout_shape: Option<(u32, u32)>,
|
|
) -> (mpsc::UnboundedReceiver<StreamToken>, AbortOnDrop) {
|
|
let (tx, rx) = mpsc::unbounded_channel();
|
|
let client = self.clone();
|
|
|
|
let handle = tokio::spawn(async move {
|
|
let result = run_session_generate(
|
|
session_lock, &client, chunks, sampling, priority,
|
|
readout_shape, &tx,
|
|
).await;
|
|
if let Err(e) = result {
|
|
log::warn!(target: "grpc",
|
|
"stream_session_mm error, forwarding to UI: {:#}", e);
|
|
let _ = tx.send(StreamToken::Error(format!("{:#}", e)));
|
|
}
|
|
});
|
|
|
|
(rx, AbortOnDrop(handle))
|
|
}
|
|
|
|
pub fn base_url(&self) -> &str { &self.base_url }
|
|
pub fn api_key(&self) -> &str { &self.api_key }
|
|
|
|
/// Fetch `/v1/readout/manifest` — returns `Ok(Some(..))` if
|
|
/// readout is enabled on the server, `Ok(None)` on 404 (disabled),
|
|
/// or an error on any other failure.
|
|
///
|
|
/// First call performs the HTTP fetch; subsequent calls (including
|
|
/// across ApiClient clones sharing the same cell) return the
|
|
/// cached result. The manifest doesn't change during a server run.
|
|
pub fn model_str(&self) -> &str { &self.model }
|
|
|
|
pub async fn fetch_readout_manifest(&self) -> Result<Option<ReadoutManifest>> {
|
|
let manifest = self.manifest.get_or_try_init(|| async {
|
|
let url = format!("{}/readout/manifest", self.base_url);
|
|
let auth = format!("Bearer {}", self.api_key);
|
|
let response = self
|
|
.client
|
|
.get_with_headers(&url, &[("Authorization", &auth)])
|
|
.await
|
|
.map_err(|e| anyhow::anyhow!("readout manifest fetch ({}): {}", url, e))?;
|
|
let status = response.status();
|
|
if status.as_u16() == 404 {
|
|
return Ok::<_, anyhow::Error>(None);
|
|
}
|
|
if !status.is_success() {
|
|
let body = response.text().await.unwrap_or_default();
|
|
let n = body.floor_char_boundary(body.len().min(500));
|
|
anyhow::bail!("readout manifest HTTP {} ({}): {}", status, url, &body[..n]);
|
|
}
|
|
Ok(Some(response.json().await?))
|
|
}).await?;
|
|
Ok(manifest.clone())
|
|
}
|
|
|
|
}
|
|
|
|
/// Body of the gRPC-path streaming task. Walks the wire chunks
|
|
/// against the session's `committed_len`, sends the delta via
|
|
/// AppendImage / intermediate prefill-only Generates / final decode
|
|
/// Generate, and translates the final Generate's Token events into
|
|
/// StreamTokens on `tx`. On success the session handle is returned
|
|
/// to `session_lock` with an updated `committed_len`; on error the
|
|
/// handle is dropped so the next call reopens.
|
|
async fn run_session_generate(
|
|
session_lock: std::sync::Arc<crate::Mutex<Option<salience::SessionHandle>>>,
|
|
client: &ApiClient,
|
|
chunks: Vec<super::context::WireChunk>,
|
|
sampling: SamplingParams,
|
|
priority: Option<i32>,
|
|
readout_shape: Option<(u32, u32)>,
|
|
tx: &mpsc::UnboundedSender<StreamToken>,
|
|
) -> Result<()> {
|
|
use std::time::Instant;
|
|
use futures::StreamExt;
|
|
use super::context::WireChunk;
|
|
use salience::pb;
|
|
|
|
let mut handle: salience::SessionHandle = {
|
|
let mut guard = session_lock.lock().await;
|
|
match guard.take() {
|
|
Some(h) => h,
|
|
None => {
|
|
drop(guard);
|
|
log::debug!(target: "grpc", "run_session_generate: opening new session");
|
|
salience::SessionHandle::open(client).await?
|
|
}
|
|
}
|
|
};
|
|
|
|
// Skip chunks already on the server. committed_len must land on
|
|
// a chunk boundary — every successful AppendImage / Generate
|
|
// advances committed_len by exactly one chunk's contribution,
|
|
// so straddling means divergence (client's AST was rewritten
|
|
// under us).
|
|
let mut acc: u32 = 0;
|
|
let mut delta_start = chunks.len();
|
|
for (i, chunk) in chunks.iter().enumerate() {
|
|
if acc == handle.committed_len {
|
|
delta_start = i;
|
|
break;
|
|
}
|
|
let len = match chunk {
|
|
WireChunk::Tokens(t) => t.len() as u32,
|
|
WireChunk::Image { known_expanded_len, .. } => *known_expanded_len,
|
|
};
|
|
if len == 0 {
|
|
anyhow::bail!(
|
|
"session divergence: chunk {} has unknown length but \
|
|
precedes committed_len {} (acc={})",
|
|
i, handle.committed_len, acc,
|
|
);
|
|
}
|
|
if acc + len > handle.committed_len {
|
|
anyhow::bail!(
|
|
"session divergence: chunk {} straddles committed_len \
|
|
(acc={}, len={}, committed={})",
|
|
i, acc, len, handle.committed_len,
|
|
);
|
|
}
|
|
acc += len;
|
|
}
|
|
if acc != handle.committed_len {
|
|
anyhow::bail!(
|
|
"session divergence: chunks sum to {} but committed_len is {}",
|
|
acc, handle.committed_len,
|
|
);
|
|
}
|
|
|
|
// Walk the delta: accumulate Tokens in `pending`; on Image,
|
|
// flush pending via prefill-only Generate then AppendImage.
|
|
let mut pending: Vec<u32> = Vec::new();
|
|
for chunk in &chunks[delta_start..] {
|
|
match chunk {
|
|
WireChunk::Tokens(t) => pending.extend_from_slice(t),
|
|
WireChunk::Image { bytes, mime, .. } => {
|
|
if !pending.is_empty() {
|
|
handle.prefill_only(std::mem::take(&mut pending)).await?;
|
|
}
|
|
let resp = handle
|
|
.append_image(bytes.clone(), mime.clone(), false)
|
|
.await?;
|
|
log::debug!(target: "grpc",
|
|
"AppendImage: N={} total_length={}",
|
|
resp.placeholder_count, resp.total_length);
|
|
let _ = tx.send(StreamToken::ImageAppended {
|
|
placeholder_count: resp.placeholder_count,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
// Final Generate: pending holds any trailing text; decode up to
|
|
// sampling.max_tokens. Request readouts on all decode positions
|
|
// via a catch-all range ending at u32::MAX — decode never
|
|
// reaches it.
|
|
let prompt_len_after_append = handle.committed_len + pending.len() as u32;
|
|
let readout_ranges = if readout_shape.is_some() {
|
|
vec![pb::PositionRange {
|
|
start: prompt_len_after_append,
|
|
end: u32::MAX,
|
|
}]
|
|
} else {
|
|
Vec::new()
|
|
};
|
|
let req = pb::GenerateRequest {
|
|
session_id: handle.session_id.clone(),
|
|
append_tokens: pending,
|
|
offset: handle.committed_len,
|
|
truncating: false,
|
|
max_tokens: sampling.max_tokens,
|
|
logprobs_ranges: Vec::new(),
|
|
logprob_top_k: 0,
|
|
readout_ranges,
|
|
temperature: sampling.temperature,
|
|
top_p: sampling.top_p,
|
|
top_k: sampling.top_k,
|
|
stop_token_ids: Vec::new(),
|
|
priority: priority.unwrap_or(0),
|
|
};
|
|
let session_id_for_log = handle.session_id.clone();
|
|
let t_generate = Instant::now();
|
|
log::debug!(target: "grpc",
|
|
"session {} Generate: offset={} append={} max_tokens={} priority={}",
|
|
session_id_for_log, req.offset, req.append_tokens.len(),
|
|
req.max_tokens, req.priority);
|
|
|
|
let mut stream = handle.generate(req).await?;
|
|
let (n_layers, n_concepts) = readout_shape.unwrap_or((0, 0));
|
|
let mut session_terminated = false;
|
|
let mut first_token_at: Option<Instant> = None;
|
|
|
|
while let Some(event) = stream.next().await {
|
|
let event = match event {
|
|
Ok(e) => e,
|
|
Err(status) => {
|
|
log::warn!(target: "grpc",
|
|
"session {} Generate stream error: {} — dropping session",
|
|
session_id_for_log, status);
|
|
session_terminated = true;
|
|
let _ = tx.send(StreamToken::Error(format!(
|
|
"Generate stream error: {}", status,
|
|
)));
|
|
break;
|
|
}
|
|
};
|
|
let Some(inner) = event.event else { continue };
|
|
match inner {
|
|
pb::generate_event::Event::Token(t) => {
|
|
if t.is_prefill { continue; }
|
|
if first_token_at.is_none() {
|
|
log::debug!(target: "grpc",
|
|
"session {} first decode token at {:?}",
|
|
session_id_for_log, t_generate.elapsed());
|
|
first_token_at = Some(Instant::now());
|
|
}
|
|
let readout = if t.readout.is_empty() {
|
|
None
|
|
} else if n_layers == 0 || n_concepts == 0 {
|
|
None
|
|
} else {
|
|
let expected = (n_layers as usize) * (n_concepts as usize);
|
|
if t.readout.len() != expected {
|
|
log::warn!(target: "grpc",
|
|
"readout shape mismatch: expected {}*{}={}, got {}",
|
|
n_layers, n_concepts, expected, t.readout.len());
|
|
None
|
|
} else {
|
|
let n = n_concepts as usize;
|
|
let mut layers: Vec<Vec<f32>> = Vec::with_capacity(n_layers as usize);
|
|
for l in 0..(n_layers as usize) {
|
|
layers.push(t.readout[l * n..(l + 1) * n].to_vec());
|
|
}
|
|
Some(layers)
|
|
}
|
|
};
|
|
if tx.send(StreamToken::Token { id: t.id, readout }).is_err() {
|
|
break;
|
|
}
|
|
}
|
|
pb::generate_event::Event::Done(d) => {
|
|
log::debug!(target: "grpc",
|
|
"session {} Done: prompt={} completion={} total={} reason={:?} elapsed={:?}",
|
|
session_id_for_log, d.prompt_tokens, d.completion_tokens,
|
|
d.total_tokens, d.finish_reason, t_generate.elapsed());
|
|
handle.committed_len = d.total_tokens;
|
|
let usage = Some(Usage {
|
|
prompt_tokens: d.prompt_tokens,
|
|
completion_tokens: d.completion_tokens,
|
|
total_tokens: d.total_tokens,
|
|
});
|
|
let _ = tx.send(StreamToken::Done { usage });
|
|
}
|
|
}
|
|
}
|
|
|
|
if !session_terminated {
|
|
let mut guard = session_lock.lock().await;
|
|
*guard = Some(handle);
|
|
}
|
|
Ok(())
|
|
}
|
|
|