salience: add gRPC client + TLS plumbing for stateful vllm sessions

Adds the client-side of a stateful gRPC protocol against vllm, plus
the TLS trust machinery so we can talk to self-signed vllm servers.

Protocol (proto/salience.proto):
  Bidi-streaming Session RPC carries OpenSession / AppendTokens /
  Generate / Cancel from client and SessionReady / PrefillProgress /
  Token / GenerateDone / Error from server. Separate Fork unary RPC
  for cheap branching (prefix cache shares KV automatically). Plus
  ListSessions, CloseSession, GetReadoutManifest admin RPCs.

  Per-token readouts ship as packed f32 ([n_layers * n_concepts] per
  token, flat). Logprobs use range-selected positions plus a top-k
  parameter — empty ranges means no logprobs, any range means emit
  sampled-token logprob at those positions, top_k > 0 adds
  alternatives.

Client (src/agent/api/salience.rs):
  Tonic-generated types under pb::, a connect() helper, with_auth()
  for bearer metadata, and a Session handle wrapping the bidi stream:
  open() handshakes SessionReady; append() is fire-and-forget;
  generate() returns impl Stream<Item = Event> that drains inbound
  until Done or terminating Error. One generate at a time per session.

Peak picker (src/agent/salience.rs):
  Pure function over ReadoutEntry traces. Per-concept z-score against
  trace global stats; contiguous above-threshold regions emit one
  peak at the local max. Configurable sigma threshold and min-std
  safety floor. Deterministic tie-break on offset then concept name.
  12 unit tests covering empty traces, flat channels, single/multi
  spikes, contiguous humps, multi-concept independence, trailing
  runs, sub-threshold noise, layer-out-of-range, manifest shape
  mismatch, and threshold tunability.

TLS (src/agent/api/http.rs):
  HttpClient::build now also loads every .pem file under
  ~/.consciousness/certs/ into the rustls root store — so dropping
  a <host>.pem in that directory is enough to trust a new self-
  signed server; no code changes per new host. Also installs the
  rustls default crypto provider explicitly via OnceLock: tonic's
  tls features pulled in both ring and aws-lc-rs on the resolver
  path, and rustls 0.23 refuses to auto-pick when either could win.

Build (build.rs, Cargo.toml):
  tonic-build generates Rust types from proto/salience.proto at
  cargo-build time, using a vendored protoc binary
  (protoc-bin-vendored) so no system install is required. New
  runtime deps: tonic, prost, async-stream, tokio-stream,
  rustls-pemfile.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-23 02:21:07 -04:00
commit 08213f9093
15 changed files with 1689 additions and 440 deletions

View file

@ -100,7 +100,7 @@ impl HttpClient {
.map_err(|e| anyhow::anyhow!("invalid server name: {e}"))?;
let connector = tokio_rustls::TlsConnector::from(self.tls.clone());
let tls = connector.connect(server_name.to_owned(), tcp).await
.context("TLS handshake")?;
.map_err(|e| anyhow::anyhow!("TLS handshake to {host}: {e}"))?;
TokioIo::new(Box::new(tls) as Box<dyn IoStream>)
} else {
TokioIo::new(Box::new(tcp) as Box<dyn IoStream>)
@ -190,6 +190,7 @@ impl HttpClientBuilder {
}
pub fn build(self) -> HttpClient {
install_rustls_crypto_provider();
let certs = rustls_native_certs::load_native_certs()
.certs.into_iter()
.collect::<Vec<_>>();
@ -197,6 +198,13 @@ impl HttpClientBuilder {
for cert in certs {
root_store.add(cert).ok();
}
// Also trust any `.pem` files under `~/.consciousness/certs/` —
// self-signed server certs for our own vllm hosts live there.
// Drop a new `<host>.pem` in the dir to trust a new server; no
// code change needed.
for cert in load_user_certs() {
root_store.add(cert).ok();
}
let tls = Arc::new(
ClientConfig::builder()
.with_root_certificates(root_store)
@ -210,6 +218,65 @@ impl HttpClientBuilder {
}
}
/// Install rustls' default crypto provider exactly once per process.
/// rustls 0.23 doesn't pick one automatically when multiple features
/// could provide it (e.g. when tonic pulls in both ring and aws-lc-rs
/// via transitive deps). Idempotent via OnceLock; safe to call from
/// multiple callers.
fn install_rustls_crypto_provider() {
static ONCE: std::sync::OnceLock<()> = std::sync::OnceLock::new();
ONCE.get_or_init(|| {
let _ = rustls::crypto::ring::default_provider().install_default();
});
}
/// Load every `.pem` file under `~/.consciousness/certs/` as a DER
/// certificate and return them. Silent on missing dir, missing files,
/// or parse errors — those are "no extra certs trusted" rather than
/// hard failures, to keep startup robust.
/// Load the concatenated PEM bytes of every `.pem` file under
/// `~/.consciousness/certs/` — suitable for passing to a tonic
/// `ClientTlsConfig::ca_certificate(Certificate::from_pem(...))` call
/// so gRPC connections trust the same self-signed servers the HTTP
/// path does.
pub(crate) fn load_user_certs_pem_bytes() -> Vec<u8> {
let mut out = Vec::new();
let Some(home) = dirs::home_dir() else { return out };
let dir = home.join(".consciousness").join("certs");
let Ok(entries) = std::fs::read_dir(&dir) else { return out };
for entry in entries.flatten() {
let path = entry.path();
if path.extension().and_then(|e| e.to_str()) != Some("pem") {
continue;
}
if let Ok(bytes) = std::fs::read(&path) {
out.extend_from_slice(&bytes);
if !bytes.ends_with(b"\n") {
out.push(b'\n');
}
}
}
out
}
fn load_user_certs() -> Vec<rustls::pki_types::CertificateDer<'static>> {
let mut out = Vec::new();
let Some(home) = dirs::home_dir() else { return out };
let dir = home.join(".consciousness").join("certs");
let Ok(entries) = std::fs::read_dir(&dir) else { return out };
for entry in entries.flatten() {
let path = entry.path();
if path.extension().and_then(|e| e.to_str()) != Some("pem") {
continue;
}
let Ok(bytes) = std::fs::read(&path) else { continue };
for cert in rustls_pemfile::certs(&mut bytes.as_slice()).flatten() {
out.push(cert);
}
}
out
}
/// Trait alias for streams that work with hyper's IO adapter.
trait IoStream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static {}
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static> IoStream for T {}

View file

@ -7,13 +7,14 @@
// Set POC_DEBUG=1 for verbose per-turn logging.
pub mod http;
pub mod salience;
use std::time::{Duration, Instant};
use std::time::Duration;
use anyhow::Result;
use tokio::sync::mpsc;
use serde::Deserialize;
use http::{HttpClient, HttpResponse};
use http::HttpClient;
#[derive(Debug, Clone, Deserialize)]
pub struct Usage {
@ -48,6 +49,7 @@ impl Drop for AbortOnDrop {
/// Sampling parameters for model generation.
#[derive(Clone, Copy)]
#[allow(dead_code)] // fields used once Generate RPC lands in a later step
pub(crate) struct SamplingParams {
pub temperature: f32,
pub top_p: f32,
@ -74,6 +76,10 @@ pub struct ApiClient {
api_key: String,
pub model: String,
base_url: String,
/// Cached readout manifest — fetched once per process and shared
/// across ApiClient clones (every Agent/fork gets the same cell).
/// `None` after fetch means the server has readout disabled (404).
manifest: std::sync::Arc<tokio::sync::OnceCell<Option<ReadoutManifest>>>,
}
impl ApiClient {
@ -88,36 +94,30 @@ impl ApiClient {
api_key: api_key.to_string(),
model: model.to_string(),
base_url: base_url.trim_end_matches('/').to_string(),
manifest: std::sync::Arc::new(tokio::sync::OnceCell::new()),
}
}
pub(crate) fn stream_completion_mm(
/// Stream generation via a gRPC session. Stubbed during the
/// unary-rewrite transition — the Generate RPC is wired in a
/// later step of this series. Until then, callers that reach
/// this path get a StreamToken::Error.
pub(crate) fn stream_session_mm(
&self,
prompt_tokens: &[u32],
images: &[super::context::WireImage],
sampling: SamplingParams,
priority: Option<i32>,
_session_lock: std::sync::Arc<crate::Mutex<Option<salience::SessionHandle>>>,
_prompt_tokens: &[u32],
_images: &[super::context::WireImage],
_sampling: SamplingParams,
_priority: Option<i32>,
) -> (mpsc::UnboundedReceiver<StreamToken>, AbortOnDrop) {
let (tx, rx) = mpsc::unbounded_channel();
let client = self.client.clone();
let api_key = self.api_key.clone();
let model = self.model.clone();
let prompt_tokens = prompt_tokens.to_vec();
let images: Vec<(Vec<u8>, String)> = images.iter()
.map(|i| (i.bytes.clone(), i.mime.clone()))
.collect();
let base_url = self.base_url.clone();
let handle = tokio::spawn(async move {
let result = stream_completions(
&client, &base_url, &api_key, &model,
&prompt_tokens, &images, &tx, sampling, priority,
).await;
if let Err(e) = result {
let _ = tx.send(StreamToken::Error(e.to_string()));
}
let _ = tx.send(StreamToken::Error(
"Generate RPC not yet wired after protocol rewrite — see \
proto/salience.proto; AppendImage / Generate land next."
.into(),
));
});
(rx, AbortOnDrop(handle))
}
@ -128,386 +128,31 @@ impl ApiClient {
/// readout is enabled on the server, `Ok(None)` on 404 (disabled),
/// or an error on any other failure.
///
/// Call once at startup and cache the result; the manifest doesn't
/// change during a server run.
/// First call performs the HTTP fetch; subsequent calls (including
/// across ApiClient clones sharing the same cell) return the
/// cached result. The manifest doesn't change during a server run.
pub async fn fetch_readout_manifest(&self) -> Result<Option<ReadoutManifest>> {
let url = format!("{}/readout/manifest", self.base_url);
let auth = format!("Bearer {}", self.api_key);
let response = self
.client
.get_with_headers(&url, &[("Authorization", &auth)])
.await
.map_err(|e| anyhow::anyhow!("readout manifest fetch ({}): {}", url, e))?;
let status = response.status();
if status.as_u16() == 404 {
return Ok(None);
}
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
let n = body.floor_char_boundary(body.len().min(500));
anyhow::bail!("readout manifest HTTP {} ({}): {}", status, url, &body[..n]);
}
Ok(Some(response.json().await?))
let manifest = self.manifest.get_or_try_init(|| async {
let url = format!("{}/readout/manifest", self.base_url);
let auth = format!("Bearer {}", self.api_key);
let response = self
.client
.get_with_headers(&url, &[("Authorization", &auth)])
.await
.map_err(|e| anyhow::anyhow!("readout manifest fetch ({}): {}", url, e))?;
let status = response.status();
if status.as_u16() == 404 {
return Ok::<_, anyhow::Error>(None);
}
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
let n = body.floor_char_boundary(body.len().min(500));
anyhow::bail!("readout manifest HTTP {} ({}): {}", status, url, &body[..n]);
}
Ok(Some(response.json().await?))
}).await?;
Ok(manifest.clone())
}
}
async fn stream_completions(
client: &HttpClient,
base_url: &str,
api_key: &str,
model: &str,
prompt_tokens: &[u32],
images: &[(Vec<u8>, String)],
tx: &mpsc::UnboundedSender<StreamToken>,
sampling: SamplingParams,
priority: Option<i32>,
) -> anyhow::Result<()> {
let mut request = serde_json::json!({
"model": model,
"prompt": prompt_tokens,
"max_tokens": 16384,
"temperature": sampling.temperature,
"top_p": sampling.top_p,
"top_k": sampling.top_k,
"stream": true,
"return_token_ids": true,
"skip_special_tokens": false,
"stop_token_ids": [super::tokenizer::IM_END],
});
if !images.is_empty() {
use base64::Engine;
let b64 = base64::engine::general_purpose::STANDARD;
let uris: Vec<String> = images.iter()
.map(|(bytes, mime)| format!("data:{};base64,{}", mime, b64.encode(bytes)))
.collect();
request["multi_modal_data"] = serde_json::json!({ "image": uris });
}
if let Some(p) = priority {
request["priority"] = serde_json::json!(p);
}
let url = format!("{}/completions", base_url);
let debug_label = format!("{} prompt tokens, model={}", prompt_tokens.len(), model);
let mut response = send_and_check(
client, &url, &request,
("Authorization", &format!("Bearer {}", api_key)),
&[], &debug_label, None,
).await?;
let mut reader = SseReader::new();
let mut usage = None;
while let Some(event) = reader.next_event(&mut response).await? {
if let Some(err_msg) = event["error"]["message"].as_str() {
anyhow::bail!("API error in stream: {}", err_msg);
}
if let Some(u) = event["usage"].as_object() {
if let Ok(u) = serde_json::from_value::<Usage>(serde_json::Value::Object(u.clone())) {
usage = Some(u);
}
}
let choices = match event["choices"].as_array() {
Some(c) => c,
None => continue,
};
for choice in choices {
// `readout`, if present, is a nested list
// `[num_tokens][n_layers][n_concepts]`. Parse it once per
// chunk and pair rows with token ids by index — the rows
// are in the same order as `token_ids`.
let readouts: Option<Vec<TokenReadout>> = choice["readout"]
.as_array()
.map(|outer| {
outer.iter().filter_map(|per_token| {
per_token.as_array().map(|layers| {
layers.iter().filter_map(|per_layer| {
per_layer.as_array().map(|vals| {
vals.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect::<Vec<f32>>()
})
}).collect::<Vec<Vec<f32>>>()
})
}).collect()
});
if let Some(ids) = choice["token_ids"].as_array() {
for (i, id_val) in ids.iter().enumerate() {
if let Some(id) = id_val.as_u64() {
let readout = readouts
.as_ref()
.and_then(|r| r.get(i).cloned());
let _ = tx.send(StreamToken::Token {
id: id as u32,
readout,
});
}
}
} else if let Some(text) = choice["text"].as_str() {
// Fallback: provider didn't return token_ids, encode locally.
// No readout available in this path — the encoder may
// produce a different token count than the server did.
if !text.is_empty() {
for id in super::tokenizer::encode(text) {
let _ = tx.send(StreamToken::Token { id, readout: None });
}
}
}
}
}
let _ = tx.send(StreamToken::Done { usage });
Ok(())
}
/// Send an HTTP request and check for errors.
pub(crate) async fn send_and_check(
client: &HttpClient,
url: &str,
body: &impl serde::Serialize,
auth_header: (&str, &str),
extra_headers: &[(&str, &str)],
debug_label: &str,
request_json: Option<&str>,
) -> Result<HttpResponse> {
let debug = std::env::var("POC_DEBUG").is_ok();
let start = Instant::now();
if debug {
let payload_size = serde_json::to_string(body)
.map(|s| s.len())
.unwrap_or(0);
dbglog!(
"request: {}K payload, {}",
payload_size / 1024, debug_label,
);
}
let mut headers: Vec<(&str, &str)> = Vec::with_capacity(extra_headers.len() + 1);
headers.push(auth_header);
headers.extend_from_slice(extra_headers);
let response = client
.send_json("POST", url, &headers, body)
.await
.map_err(|e| {
let msg = e.to_string();
let cause = if msg.contains("connect timeout") || msg.contains("TCP connect") {
"connection refused"
} else if msg.contains("request timeout") {
"request timed out"
} else {
"request error"
};
anyhow::anyhow!("{} ({}): {}", cause, url, msg)
})?;
let status = response.status();
let elapsed = start.elapsed();
if debug {
for name in [
"x-ratelimit-remaining",
"x-ratelimit-limit",
"x-request-id",
] {
if let Some(val) = response.header(name) {
dbglog!("header {}: {}", name, val);
}
}
}
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
dbglog!(
"HTTP {} after {:.1}s ({}): {}",
status,
elapsed.as_secs_f64(),
url,
&body[..body.floor_char_boundary(body.len().min(500))]
);
if let Some(json) = request_json {
let log_dir = dirs::home_dir()
.unwrap_or_default()
.join(".consciousness/logs/failed-requests");
let _ = std::fs::create_dir_all(&log_dir);
let ts = chrono::Local::now().format("%Y%m%dT%H%M%S");
let path = log_dir.join(format!("{}.json", ts));
if std::fs::write(&path, json).is_ok() {
dbglog!(
"saved failed request to {} (HTTP {})", path.display(), status
);
}
}
anyhow::bail!("HTTP {} ({}): {}", status, url, &body[..body.floor_char_boundary(body.len().min(1000))]);
}
if debug {
dbglog!(
"connected in {:.1}s (HTTP {})",
elapsed.as_secs_f64(),
status.as_u16()
);
}
Ok(response)
}
/// SSE stream reader. Handles the generic SSE plumbing shared by both
/// backends: chunk reading with timeout, line buffering, `data:` prefix
/// stripping, `[DONE]` detection, JSON parsing, and parse error diagnostics.
/// Yields parsed events as serde_json::Value — each backend handles its
/// own event types.
pub(crate) struct SseReader {
line_buf: String,
chunk_timeout: Duration,
pub stream_start: Instant,
pub chunks_received: u64,
pub sse_lines_parsed: u64,
pub sse_parse_errors: u64,
debug: bool,
done: bool,
/// Serialized request payload — saved to disk on errors for replay debugging.
pub(crate) request_json: Option<String>,
}
impl SseReader {
pub(crate) fn new() -> Self {
Self {
line_buf: String::new(),
chunk_timeout: Duration::from_secs(crate::config::get().api_stream_timeout_secs),
stream_start: Instant::now(),
chunks_received: 0,
sse_lines_parsed: 0,
sse_parse_errors: 0,
debug: std::env::var("POC_DEBUG").is_ok(),
done: false,
request_json: None,
}
}
/// Attach the serialized request payload for error diagnostics.
/// Save the request payload to disk for replay debugging.
fn save_failed_request(&self, reason: &str) {
let Some(ref json) = self.request_json else { return };
let log_dir = dirs::home_dir()
.unwrap_or_default()
.join(".consciousness/logs/failed-requests");
let _ = std::fs::create_dir_all(&log_dir);
let ts = chrono::Local::now().format("%Y%m%dT%H%M%S");
let path = log_dir.join(format!("{}.json", ts));
if std::fs::write(&path, json).is_ok() {
dbglog!(
"saved failed request to {} ({})", path.display(), reason
);
}
}
/// Read the next SSE event from the response stream.
/// Returns Ok(Some(value)) for each parsed data line,
/// Ok(None) when the stream ends or [DONE] is received.
pub(crate) async fn next_event(
&mut self,
response: &mut HttpResponse,
) -> Result<Option<serde_json::Value>> {
loop {
// Drain complete lines from the buffer before reading more chunks
while let Some(newline_pos) = self.line_buf.find('\n') {
let line = self.line_buf[..newline_pos].trim().to_string();
self.line_buf = self.line_buf[newline_pos + 1..].to_string();
if line == "data: [DONE]" {
self.done = true;
return Ok(None);
}
if line.is_empty()
|| line.starts_with("event: ")
|| !line.starts_with("data: ")
{
continue;
}
let json_str = &line[6..];
self.sse_lines_parsed += 1;
match serde_json::from_str(json_str) {
Ok(v) => return Ok(Some(v)),
Err(e) => {
self.sse_parse_errors += 1;
if self.sse_parse_errors == 1 || self.debug {
let preview = if json_str.len() > 200 {
format!("{}...", &json_str[..200])
} else {
json_str.to_string()
};
dbglog!(
"SSE parse error (#{}) {}: {}",
self.sse_parse_errors, e, preview
);
}
continue;
}
}
}
if self.done {
return Ok(None);
}
// Read more data from the response stream
match tokio::time::timeout(self.chunk_timeout, response.chunk()).await {
Ok(Ok(Some(chunk))) => {
self.chunks_received += 1;
self.line_buf.push_str(&String::from_utf8_lossy(&chunk));
}
Ok(Ok(None)) => return Ok(None),
Ok(Err(e)) => {
let buf_preview = if self.line_buf.is_empty() {
"(empty)".to_string()
} else {
let n = self.line_buf.len().min(500);
format!("{}B: {}", self.line_buf.len(), &self.line_buf[..n])
};
let msg = format!(
"stream error after {} chunks, {:.1}s, {} sse lines: {} | buf: {}",
self.chunks_received,
self.stream_start.elapsed().as_secs_f64(),
self.sse_lines_parsed,
e, buf_preview,
);
dbglog!("{}", msg);
self.save_failed_request(&msg);
return Err(e.into());
}
Err(_) => {
let buf_preview = if self.line_buf.is_empty() {
"(empty)".to_string()
} else {
let n = self.line_buf.len().min(500);
format!("{}B: {}", self.line_buf.len(), &self.line_buf[..n])
};
let msg = format!(
"stream timeout: {}s, {} chunks, {} sse lines, {:.1}s elapsed | buf: {}",
self.chunk_timeout.as_secs(),
self.chunks_received,
self.sse_lines_parsed,
self.stream_start.elapsed().as_secs_f64(),
buf_preview,
);
dbglog!("{}", msg);
self.save_failed_request(&msg);
anyhow::bail!(
"stream timeout: no data for {}s ({} chunks received)",
self.chunk_timeout.as_secs(),
self.chunks_received
);
}
}
}
}
}

249
src/agent/api/salience.rs Normal file
View file

@ -0,0 +1,249 @@
// agent/api/salience.rs — gRPC client bindings for salience.v1.
//
// Thin wrapper around the tonic-generated types. Every RPC except
// Generate is unary; Generate is server-streaming. Free functions
// (open/close session) wrap the lifecycle RPCs; `SessionHandle` just
// carries the id + connection params so later RPCs can reuse them.
//
// The old bidi Session() API is gone — see git history for its shape.
#![allow(clippy::enum_variant_names)]
use anyhow::{Context, Result};
use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint};
/// Generated prost + tonic types for salience.v1. Call sites use
/// `pb::OpenSessionRequest`, `pb::Token`, etc.
pub mod pb {
tonic::include_proto!("salience.v1");
}
pub type SalienceClient = pb::salience_client::SalienceClient<Channel>;
/// Open a TLS-aware gRPC channel to the salience server. `base_url`
/// looks like `https://host:8443`. User-provided CA certs under
/// `~/.consciousness/certs/` are trusted in addition to the system
/// roots (for self-signed server certs).
pub async fn connect(base_url: &str) -> Result<SalienceClient> {
let mut endpoint = Endpoint::from_shared(base_url.to_string())
.with_context(|| format!("invalid salience endpoint: {}", base_url))?
.connect_timeout(std::time::Duration::from_secs(30))
.timeout(std::time::Duration::from_secs(600));
if base_url.starts_with("https://") {
let user_certs = super::http::load_user_certs_pem_bytes();
let mut tls = ClientTlsConfig::new().with_native_roots();
if !user_certs.is_empty() {
tls = tls.ca_certificate(Certificate::from_pem(user_certs));
}
endpoint = endpoint
.tls_config(tls)
.with_context(|| "configuring tonic TLS")?;
}
let channel = endpoint
.connect()
.await
.with_context(|| format!("failed to connect to salience server at {}", base_url))?;
Ok(pb::salience_client::SalienceClient::new(channel))
}
/// Derive the gRPC base URL from the HTTP completions base URL.
///
/// vLLM's salience gRPC server listens on a different port (8443) from
/// the HTTP endpoint (8000) and accepts no path component. Given an
/// HTTP base like `https://host:8000/v1`, produce `https://host:8443`.
/// No-op when the path is empty and the port isn't 8000.
pub fn derive_grpc_url(http_base: &str) -> String {
let mut url = http_base.trim_end_matches('/').to_string();
if let Some(proto_end) = url.find("://") {
let rest_start = proto_end + 3;
if let Some(path_slash) = url[rest_start..].find('/') {
url.truncate(rest_start + path_slash);
}
}
url.replace(":8000", ":8443")
}
/// Attach a bearer token to a tonic request as gRPC metadata.
pub fn with_auth<T>(req: &mut tonic::Request<T>, api_key: &str) {
if api_key.is_empty() {
return;
}
let bearer = format!("Bearer {}", api_key);
if let Ok(val) = bearer.parse() {
req.metadata_mut().insert("authorization", val);
}
}
/// Call the server's `OpenSession` RPC and return the response.
pub async fn open_session(
base_url: &str,
api_key: &str,
model: &str,
) -> Result<pb::OpenSessionResponse> {
let mut client = connect(base_url).await?;
let mut req = tonic::Request::new(pb::OpenSessionRequest {
model: model.to_string(),
});
with_auth(&mut req, api_key);
let resp = client
.open_session(req)
.await
.with_context(|| "OpenSession RPC failed")?;
Ok(resp.into_inner())
}
/// Call the server's `CloseSession` RPC. Idempotent on the server.
pub async fn close_session(base_url: &str, api_key: &str, session_id: &str) -> Result<()> {
let mut client = connect(base_url).await?;
let mut req = tonic::Request::new(pb::CloseSessionRequest {
session_id: session_id.to_string(),
});
with_auth(&mut req, api_key);
client
.close_session(req)
.await
.with_context(|| "CloseSession RPC failed")?;
Ok(())
}
/// Append an image to a session. Server decodes the image, computes N
/// via vLLM's own multimodal pipeline, writes the full vision block
/// (`<|vision_start|> + IMAGE_PAD×N + <|vision_end|>`) into
/// session.tokens, and returns (N, new total length).
///
/// `offset` is the client's view of the session's current token count;
/// the server rejects if it diverges from its own (unless
/// `truncating=true`, in which case the server slices to `offset`
/// first — but never through a vision block).
pub async fn append_image(
base_url: &str,
api_key: &str,
session_id: &str,
data: Vec<u8>,
mime: String,
offset: u32,
truncating: bool,
) -> Result<pb::AppendImageResponse> {
let mut client = connect(base_url).await?;
let mut req = tonic::Request::new(pb::AppendImageRequest {
session_id: session_id.to_string(),
data,
mime,
offset,
truncating,
});
with_auth(&mut req, api_key);
let resp = client
.append_image(req)
.await
.with_context(|| "AppendImage RPC failed")?;
Ok(resp.into_inner())
}
/// Handle to a server-side session. Carries the id + connection params
/// so subsequent per-session RPCs (AppendImage, Generate, ForkSession)
/// can be issued without the caller juggling base_url / api_key each
/// time.
pub struct SessionHandle {
pub session_id: String,
pub max_model_len: u32,
pub base_url: String,
pub api_key: String,
}
impl SessionHandle {
pub async fn open(base_url: &str, api_key: &str, model: &str) -> Result<Self> {
let grpc_url = derive_grpc_url(base_url);
log::debug!(target: "grpc",
"SessionHandle::open http_base={} -> grpc_url={}",
base_url, grpc_url);
let resp = open_session(&grpc_url, api_key, model).await?;
log::debug!(target: "grpc",
"SessionHandle::open session_id={} max_model_len={}",
resp.session_id, resp.max_model_len);
Ok(Self {
session_id: resp.session_id,
max_model_len: resp.max_model_len,
base_url: grpc_url,
api_key: api_key.to_string(),
})
}
pub async fn close(self) -> Result<()> {
close_session(&self.base_url, &self.api_key, &self.session_id).await
}
/// Append an image via the server-side vision block. See
/// `append_image` free function for full semantics.
pub async fn append_image(
&self,
data: Vec<u8>,
mime: String,
offset: u32,
truncating: bool,
) -> Result<pb::AppendImageResponse> {
append_image(
&self.base_url,
&self.api_key,
&self.session_id,
data,
mime,
offset,
truncating,
)
.await
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn generated_types_compile() {
// Exercise the shape of the new proto types — if build.rs
// stops regenerating against the proto, this stops compiling.
let _open = pb::OpenSessionRequest {
model: "qwen3-vl".into(),
};
let _tok = pb::Token {
id: 42,
position: 0,
is_prefill: false,
readout: vec![0.1, 0.2, 0.3],
logprobs: vec![pb::TokenLogprob {
id: 1,
logprob: -0.5,
}],
sampled_logprob: -0.1,
has_sampled_logprob: true,
};
let _done = pb::GenerateDone {
prompt_tokens: 10,
completion_tokens: 20,
total_tokens: 30,
finish_reason: pb::generate_done::FinishReason::Eos as i32,
};
let _evt = pb::GenerateEvent {
event: Some(pb::generate_event::Event::Done(_done)),
};
}
#[test]
fn derive_grpc_url_cases() {
assert_eq!(
derive_grpc_url("https://host:8000/v1"),
"https://host:8443",
);
assert_eq!(
derive_grpc_url("https://host:8000/"),
"https://host:8443",
);
assert_eq!(
derive_grpc_url("https://host:9000/v1"),
"https://host:9000",
);
}
}