forked from kent/consciousness
salience: add gRPC client + TLS plumbing for stateful vllm sessions
Adds the client side of a stateful gRPC protocol against vllm, plus the TLS trust machinery so we can talk to self-signed vllm servers. Protocol (proto/salience.proto): Bidi-streaming Session RPC carries OpenSession / AppendTokens / Generate / Cancel from client and SessionReady / PrefillProgress / Token / GenerateDone / Error from server. Separate Fork unary RPC for cheap branching (prefix cache shares KV automatically). Plus ListSessions, CloseSession, GetReadoutManifest admin RPCs. Per-token readouts ship as packed f32 ([n_layers * n_concepts] per token, flat). Logprobs use range-selected positions plus a top-k parameter — empty ranges mean no logprobs, any range means emit sampled-token logprob at those positions, top_k > 0 adds alternatives. Client (src/agent/api/salience.rs): Tonic-generated types under pb::, a connect() helper, with_auth() for bearer metadata, and a Session handle wrapping the bidi stream: open() handshakes SessionReady; append() is fire-and-forget; generate() returns impl Stream<Item = Event> that drains inbound until Done or terminating Error. One generate at a time per session. Peak picker (src/agent/salience.rs): Pure function over ReadoutEntry traces. Per-concept z-score against trace global stats; contiguous above-threshold regions emit one peak at the local max. Configurable sigma threshold and min-std safety floor. Deterministic tie-break on offset then concept name. 12 unit tests covering empty traces, flat channels, single/multi spikes, contiguous humps, multi-concept independence, trailing runs, sub-threshold noise, layer-out-of-range, manifest shape mismatch, and threshold tunability. TLS (src/agent/api/http.rs): HttpClient::build now also loads every .pem file under ~/.consciousness/certs/ into the rustls root store — so dropping a <host>.pem in that directory is enough to trust a new self-signed server; no code changes per new host.
Also installs the rustls default crypto provider explicitly via OnceLock: tonic's tls features pulled in both ring and aws-lc-rs on the resolver path, and rustls 0.23 refuses to auto-pick when either could win. Build (build.rs, Cargo.toml): tonic-build generates Rust types from proto/salience.proto at cargo-build time, using a vendored protoc binary (protoc-bin-vendored) so no system install is required. New runtime deps: tonic, prost, async-stream, tokio-stream, rustls-pemfile. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
0e459aae92
commit
08213f9093
15 changed files with 1689 additions and 440 deletions
|
|
@ -100,7 +100,7 @@ impl HttpClient {
|
|||
.map_err(|e| anyhow::anyhow!("invalid server name: {e}"))?;
|
||||
let connector = tokio_rustls::TlsConnector::from(self.tls.clone());
|
||||
let tls = connector.connect(server_name.to_owned(), tcp).await
|
||||
.context("TLS handshake")?;
|
||||
.map_err(|e| anyhow::anyhow!("TLS handshake to {host}: {e}"))?;
|
||||
TokioIo::new(Box::new(tls) as Box<dyn IoStream>)
|
||||
} else {
|
||||
TokioIo::new(Box::new(tcp) as Box<dyn IoStream>)
|
||||
|
|
@ -190,6 +190,7 @@ impl HttpClientBuilder {
|
|||
}
|
||||
|
||||
pub fn build(self) -> HttpClient {
|
||||
install_rustls_crypto_provider();
|
||||
let certs = rustls_native_certs::load_native_certs()
|
||||
.certs.into_iter()
|
||||
.collect::<Vec<_>>();
|
||||
|
|
@ -197,6 +198,13 @@ impl HttpClientBuilder {
|
|||
for cert in certs {
|
||||
root_store.add(cert).ok();
|
||||
}
|
||||
// Also trust any `.pem` files under `~/.consciousness/certs/` —
|
||||
// self-signed server certs for our own vllm hosts live there.
|
||||
// Drop a new `<host>.pem` in the dir to trust a new server; no
|
||||
// code change needed.
|
||||
for cert in load_user_certs() {
|
||||
root_store.add(cert).ok();
|
||||
}
|
||||
let tls = Arc::new(
|
||||
ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
|
|
@ -210,6 +218,65 @@ impl HttpClientBuilder {
|
|||
}
|
||||
}
|
||||
|
||||
/// Install rustls' default crypto provider exactly once per process.
|
||||
/// rustls 0.23 doesn't pick one automatically when multiple features
|
||||
/// could provide it (e.g. when tonic pulls in both ring and aws-lc-rs
|
||||
/// via transitive deps). Idempotent via OnceLock; safe to call from
|
||||
/// multiple callers.
|
||||
fn install_rustls_crypto_provider() {
|
||||
static ONCE: std::sync::OnceLock<()> = std::sync::OnceLock::new();
|
||||
ONCE.get_or_init(|| {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
});
|
||||
}
|
||||
|
||||
/// Load every `.pem` file under `~/.consciousness/certs/` as a DER
|
||||
/// certificate and return them. Silent on missing dir, missing files,
|
||||
/// or parse errors — those are "no extra certs trusted" rather than
|
||||
/// hard failures, to keep startup robust.
|
||||
/// Load the concatenated PEM bytes of every `.pem` file under
|
||||
/// `~/.consciousness/certs/` — suitable for passing to a tonic
|
||||
/// `ClientTlsConfig::ca_certificate(Certificate::from_pem(...))` call
|
||||
/// so gRPC connections trust the same self-signed servers the HTTP
|
||||
/// path does.
|
||||
pub(crate) fn load_user_certs_pem_bytes() -> Vec<u8> {
|
||||
let mut out = Vec::new();
|
||||
let Some(home) = dirs::home_dir() else { return out };
|
||||
let dir = home.join(".consciousness").join("certs");
|
||||
let Ok(entries) = std::fs::read_dir(&dir) else { return out };
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|e| e.to_str()) != Some("pem") {
|
||||
continue;
|
||||
}
|
||||
if let Ok(bytes) = std::fs::read(&path) {
|
||||
out.extend_from_slice(&bytes);
|
||||
if !bytes.ends_with(b"\n") {
|
||||
out.push(b'\n');
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn load_user_certs() -> Vec<rustls::pki_types::CertificateDer<'static>> {
|
||||
let mut out = Vec::new();
|
||||
let Some(home) = dirs::home_dir() else { return out };
|
||||
let dir = home.join(".consciousness").join("certs");
|
||||
let Ok(entries) = std::fs::read_dir(&dir) else { return out };
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|e| e.to_str()) != Some("pem") {
|
||||
continue;
|
||||
}
|
||||
let Ok(bytes) = std::fs::read(&path) else { continue };
|
||||
for cert in rustls_pemfile::certs(&mut bytes.as_slice()).flatten() {
|
||||
out.push(cert);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Trait alias for streams that work with hyper's IO adapter.
|
||||
trait IoStream: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static {}
|
||||
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static> IoStream for T {}
|
||||
|
|
|
|||
|
|
@ -7,13 +7,14 @@
|
|||
// Set POC_DEBUG=1 for verbose per-turn logging.
|
||||
|
||||
pub mod http;
|
||||
pub mod salience;
|
||||
|
||||
use std::time::{Duration, Instant};
|
||||
use std::time::Duration;
|
||||
use anyhow::Result;
|
||||
use tokio::sync::mpsc;
|
||||
use serde::Deserialize;
|
||||
|
||||
use http::{HttpClient, HttpResponse};
|
||||
use http::HttpClient;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Usage {
|
||||
|
|
@ -48,6 +49,7 @@ impl Drop for AbortOnDrop {
|
|||
|
||||
/// Sampling parameters for model generation.
|
||||
#[derive(Clone, Copy)]
|
||||
#[allow(dead_code)] // fields used once Generate RPC lands in a later step
|
||||
pub(crate) struct SamplingParams {
|
||||
pub temperature: f32,
|
||||
pub top_p: f32,
|
||||
|
|
@ -74,6 +76,10 @@ pub struct ApiClient {
|
|||
api_key: String,
|
||||
pub model: String,
|
||||
base_url: String,
|
||||
/// Cached readout manifest — fetched once per process and shared
|
||||
/// across ApiClient clones (every Agent/fork gets the same cell).
|
||||
/// `None` after fetch means the server has readout disabled (404).
|
||||
manifest: std::sync::Arc<tokio::sync::OnceCell<Option<ReadoutManifest>>>,
|
||||
}
|
||||
|
||||
impl ApiClient {
|
||||
|
|
@ -88,36 +94,30 @@ impl ApiClient {
|
|||
api_key: api_key.to_string(),
|
||||
model: model.to_string(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
manifest: std::sync::Arc::new(tokio::sync::OnceCell::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn stream_completion_mm(
|
||||
/// Stream generation via a gRPC session. Stubbed during the
|
||||
/// unary-rewrite transition — the Generate RPC is wired in a
|
||||
/// later step of this series. Until then, callers that reach
|
||||
/// this path get a StreamToken::Error.
|
||||
pub(crate) fn stream_session_mm(
|
||||
&self,
|
||||
prompt_tokens: &[u32],
|
||||
images: &[super::context::WireImage],
|
||||
sampling: SamplingParams,
|
||||
priority: Option<i32>,
|
||||
_session_lock: std::sync::Arc<crate::Mutex<Option<salience::SessionHandle>>>,
|
||||
_prompt_tokens: &[u32],
|
||||
_images: &[super::context::WireImage],
|
||||
_sampling: SamplingParams,
|
||||
_priority: Option<i32>,
|
||||
) -> (mpsc::UnboundedReceiver<StreamToken>, AbortOnDrop) {
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
let client = self.client.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let model = self.model.clone();
|
||||
let prompt_tokens = prompt_tokens.to_vec();
|
||||
let images: Vec<(Vec<u8>, String)> = images.iter()
|
||||
.map(|i| (i.bytes.clone(), i.mime.clone()))
|
||||
.collect();
|
||||
let base_url = self.base_url.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let result = stream_completions(
|
||||
&client, &base_url, &api_key, &model,
|
||||
&prompt_tokens, &images, &tx, sampling, priority,
|
||||
).await;
|
||||
if let Err(e) = result {
|
||||
let _ = tx.send(StreamToken::Error(e.to_string()));
|
||||
}
|
||||
let _ = tx.send(StreamToken::Error(
|
||||
"Generate RPC not yet wired after protocol rewrite — see \
|
||||
proto/salience.proto; AppendImage / Generate land next."
|
||||
.into(),
|
||||
));
|
||||
});
|
||||
|
||||
(rx, AbortOnDrop(handle))
|
||||
}
|
||||
|
||||
|
|
@ -128,386 +128,31 @@ impl ApiClient {
|
|||
/// readout is enabled on the server, `Ok(None)` on 404 (disabled),
|
||||
/// or an error on any other failure.
|
||||
///
|
||||
/// Call once at startup and cache the result; the manifest doesn't
|
||||
/// change during a server run.
|
||||
/// First call performs the HTTP fetch; subsequent calls (including
|
||||
/// across ApiClient clones sharing the same cell) return the
|
||||
/// cached result. The manifest doesn't change during a server run.
|
||||
pub async fn fetch_readout_manifest(&self) -> Result<Option<ReadoutManifest>> {
|
||||
let url = format!("{}/readout/manifest", self.base_url);
|
||||
let auth = format!("Bearer {}", self.api_key);
|
||||
let response = self
|
||||
.client
|
||||
.get_with_headers(&url, &[("Authorization", &auth)])
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("readout manifest fetch ({}): {}", url, e))?;
|
||||
let status = response.status();
|
||||
if status.as_u16() == 404 {
|
||||
return Ok(None);
|
||||
}
|
||||
if !status.is_success() {
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
let n = body.floor_char_boundary(body.len().min(500));
|
||||
anyhow::bail!("readout manifest HTTP {} ({}): {}", status, url, &body[..n]);
|
||||
}
|
||||
Ok(Some(response.json().await?))
|
||||
let manifest = self.manifest.get_or_try_init(|| async {
|
||||
let url = format!("{}/readout/manifest", self.base_url);
|
||||
let auth = format!("Bearer {}", self.api_key);
|
||||
let response = self
|
||||
.client
|
||||
.get_with_headers(&url, &[("Authorization", &auth)])
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("readout manifest fetch ({}): {}", url, e))?;
|
||||
let status = response.status();
|
||||
if status.as_u16() == 404 {
|
||||
return Ok::<_, anyhow::Error>(None);
|
||||
}
|
||||
if !status.is_success() {
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
let n = body.floor_char_boundary(body.len().min(500));
|
||||
anyhow::bail!("readout manifest HTTP {} ({}): {}", status, url, &body[..n]);
|
||||
}
|
||||
Ok(Some(response.json().await?))
|
||||
}).await?;
|
||||
Ok(manifest.clone())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
async fn stream_completions(
|
||||
client: &HttpClient,
|
||||
base_url: &str,
|
||||
api_key: &str,
|
||||
model: &str,
|
||||
prompt_tokens: &[u32],
|
||||
images: &[(Vec<u8>, String)],
|
||||
tx: &mpsc::UnboundedSender<StreamToken>,
|
||||
sampling: SamplingParams,
|
||||
priority: Option<i32>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut request = serde_json::json!({
|
||||
"model": model,
|
||||
"prompt": prompt_tokens,
|
||||
"max_tokens": 16384,
|
||||
"temperature": sampling.temperature,
|
||||
"top_p": sampling.top_p,
|
||||
"top_k": sampling.top_k,
|
||||
"stream": true,
|
||||
"return_token_ids": true,
|
||||
"skip_special_tokens": false,
|
||||
"stop_token_ids": [super::tokenizer::IM_END],
|
||||
});
|
||||
if !images.is_empty() {
|
||||
use base64::Engine;
|
||||
let b64 = base64::engine::general_purpose::STANDARD;
|
||||
let uris: Vec<String> = images.iter()
|
||||
.map(|(bytes, mime)| format!("data:{};base64,{}", mime, b64.encode(bytes)))
|
||||
.collect();
|
||||
request["multi_modal_data"] = serde_json::json!({ "image": uris });
|
||||
}
|
||||
if let Some(p) = priority {
|
||||
request["priority"] = serde_json::json!(p);
|
||||
}
|
||||
|
||||
let url = format!("{}/completions", base_url);
|
||||
let debug_label = format!("{} prompt tokens, model={}", prompt_tokens.len(), model);
|
||||
|
||||
let mut response = send_and_check(
|
||||
client, &url, &request,
|
||||
("Authorization", &format!("Bearer {}", api_key)),
|
||||
&[], &debug_label, None,
|
||||
).await?;
|
||||
|
||||
let mut reader = SseReader::new();
|
||||
let mut usage = None;
|
||||
|
||||
while let Some(event) = reader.next_event(&mut response).await? {
|
||||
if let Some(err_msg) = event["error"]["message"].as_str() {
|
||||
anyhow::bail!("API error in stream: {}", err_msg);
|
||||
}
|
||||
|
||||
if let Some(u) = event["usage"].as_object() {
|
||||
if let Ok(u) = serde_json::from_value::<Usage>(serde_json::Value::Object(u.clone())) {
|
||||
usage = Some(u);
|
||||
}
|
||||
}
|
||||
|
||||
let choices = match event["choices"].as_array() {
|
||||
Some(c) => c,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
for choice in choices {
|
||||
// `readout`, if present, is a nested list
|
||||
// `[num_tokens][n_layers][n_concepts]`. Parse it once per
|
||||
// chunk and pair rows with token ids by index — the rows
|
||||
// are in the same order as `token_ids`.
|
||||
let readouts: Option<Vec<TokenReadout>> = choice["readout"]
|
||||
.as_array()
|
||||
.map(|outer| {
|
||||
outer.iter().filter_map(|per_token| {
|
||||
per_token.as_array().map(|layers| {
|
||||
layers.iter().filter_map(|per_layer| {
|
||||
per_layer.as_array().map(|vals| {
|
||||
vals.iter()
|
||||
.filter_map(|v| v.as_f64().map(|f| f as f32))
|
||||
.collect::<Vec<f32>>()
|
||||
})
|
||||
}).collect::<Vec<Vec<f32>>>()
|
||||
})
|
||||
}).collect()
|
||||
});
|
||||
|
||||
if let Some(ids) = choice["token_ids"].as_array() {
|
||||
for (i, id_val) in ids.iter().enumerate() {
|
||||
if let Some(id) = id_val.as_u64() {
|
||||
let readout = readouts
|
||||
.as_ref()
|
||||
.and_then(|r| r.get(i).cloned());
|
||||
let _ = tx.send(StreamToken::Token {
|
||||
id: id as u32,
|
||||
readout,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if let Some(text) = choice["text"].as_str() {
|
||||
// Fallback: provider didn't return token_ids, encode locally.
|
||||
// No readout available in this path — the encoder may
|
||||
// produce a different token count than the server did.
|
||||
if !text.is_empty() {
|
||||
for id in super::tokenizer::encode(text) {
|
||||
let _ = tx.send(StreamToken::Token { id, readout: None });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = tx.send(StreamToken::Done { usage });
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send an HTTP request and check for errors.
|
||||
pub(crate) async fn send_and_check(
|
||||
client: &HttpClient,
|
||||
url: &str,
|
||||
body: &impl serde::Serialize,
|
||||
auth_header: (&str, &str),
|
||||
extra_headers: &[(&str, &str)],
|
||||
debug_label: &str,
|
||||
request_json: Option<&str>,
|
||||
) -> Result<HttpResponse> {
|
||||
let debug = std::env::var("POC_DEBUG").is_ok();
|
||||
let start = Instant::now();
|
||||
|
||||
if debug {
|
||||
let payload_size = serde_json::to_string(body)
|
||||
.map(|s| s.len())
|
||||
.unwrap_or(0);
|
||||
dbglog!(
|
||||
"request: {}K payload, {}",
|
||||
payload_size / 1024, debug_label,
|
||||
);
|
||||
}
|
||||
|
||||
let mut headers: Vec<(&str, &str)> = Vec::with_capacity(extra_headers.len() + 1);
|
||||
headers.push(auth_header);
|
||||
headers.extend_from_slice(extra_headers);
|
||||
|
||||
let response = client
|
||||
.send_json("POST", url, &headers, body)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let msg = e.to_string();
|
||||
let cause = if msg.contains("connect timeout") || msg.contains("TCP connect") {
|
||||
"connection refused"
|
||||
} else if msg.contains("request timeout") {
|
||||
"request timed out"
|
||||
} else {
|
||||
"request error"
|
||||
};
|
||||
anyhow::anyhow!("{} ({}): {}", cause, url, msg)
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
if debug {
|
||||
for name in [
|
||||
"x-ratelimit-remaining",
|
||||
"x-ratelimit-limit",
|
||||
"x-request-id",
|
||||
] {
|
||||
if let Some(val) = response.header(name) {
|
||||
dbglog!("header {}: {}", name, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !status.is_success() {
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
dbglog!(
|
||||
"HTTP {} after {:.1}s ({}): {}",
|
||||
status,
|
||||
elapsed.as_secs_f64(),
|
||||
url,
|
||||
&body[..body.floor_char_boundary(body.len().min(500))]
|
||||
);
|
||||
if let Some(json) = request_json {
|
||||
let log_dir = dirs::home_dir()
|
||||
.unwrap_or_default()
|
||||
.join(".consciousness/logs/failed-requests");
|
||||
let _ = std::fs::create_dir_all(&log_dir);
|
||||
let ts = chrono::Local::now().format("%Y%m%dT%H%M%S");
|
||||
let path = log_dir.join(format!("{}.json", ts));
|
||||
if std::fs::write(&path, json).is_ok() {
|
||||
dbglog!(
|
||||
"saved failed request to {} (HTTP {})", path.display(), status
|
||||
);
|
||||
}
|
||||
}
|
||||
anyhow::bail!("HTTP {} ({}): {}", status, url, &body[..body.floor_char_boundary(body.len().min(1000))]);
|
||||
}
|
||||
|
||||
if debug {
|
||||
dbglog!(
|
||||
"connected in {:.1}s (HTTP {})",
|
||||
elapsed.as_secs_f64(),
|
||||
status.as_u16()
|
||||
);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// SSE stream reader. Handles the generic SSE plumbing shared by both
|
||||
/// backends: chunk reading with timeout, line buffering, `data:` prefix
|
||||
/// stripping, `[DONE]` detection, JSON parsing, and parse error diagnostics.
|
||||
/// Yields parsed events as serde_json::Value — each backend handles its
|
||||
/// own event types.
|
||||
pub(crate) struct SseReader {
|
||||
line_buf: String,
|
||||
chunk_timeout: Duration,
|
||||
pub stream_start: Instant,
|
||||
pub chunks_received: u64,
|
||||
pub sse_lines_parsed: u64,
|
||||
pub sse_parse_errors: u64,
|
||||
debug: bool,
|
||||
done: bool,
|
||||
/// Serialized request payload — saved to disk on errors for replay debugging.
|
||||
pub(crate) request_json: Option<String>,
|
||||
}
|
||||
|
||||
impl SseReader {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
line_buf: String::new(),
|
||||
chunk_timeout: Duration::from_secs(crate::config::get().api_stream_timeout_secs),
|
||||
stream_start: Instant::now(),
|
||||
chunks_received: 0,
|
||||
sse_lines_parsed: 0,
|
||||
sse_parse_errors: 0,
|
||||
debug: std::env::var("POC_DEBUG").is_ok(),
|
||||
done: false,
|
||||
request_json: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Attach the serialized request payload for error diagnostics.
|
||||
/// Save the request payload to disk for replay debugging.
|
||||
fn save_failed_request(&self, reason: &str) {
|
||||
let Some(ref json) = self.request_json else { return };
|
||||
let log_dir = dirs::home_dir()
|
||||
.unwrap_or_default()
|
||||
.join(".consciousness/logs/failed-requests");
|
||||
let _ = std::fs::create_dir_all(&log_dir);
|
||||
let ts = chrono::Local::now().format("%Y%m%dT%H%M%S");
|
||||
let path = log_dir.join(format!("{}.json", ts));
|
||||
if std::fs::write(&path, json).is_ok() {
|
||||
dbglog!(
|
||||
"saved failed request to {} ({})", path.display(), reason
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Read the next SSE event from the response stream.
|
||||
/// Returns Ok(Some(value)) for each parsed data line,
|
||||
/// Ok(None) when the stream ends or [DONE] is received.
|
||||
pub(crate) async fn next_event(
|
||||
&mut self,
|
||||
response: &mut HttpResponse,
|
||||
) -> Result<Option<serde_json::Value>> {
|
||||
loop {
|
||||
// Drain complete lines from the buffer before reading more chunks
|
||||
while let Some(newline_pos) = self.line_buf.find('\n') {
|
||||
let line = self.line_buf[..newline_pos].trim().to_string();
|
||||
self.line_buf = self.line_buf[newline_pos + 1..].to_string();
|
||||
|
||||
if line == "data: [DONE]" {
|
||||
self.done = true;
|
||||
return Ok(None);
|
||||
}
|
||||
if line.is_empty()
|
||||
|| line.starts_with("event: ")
|
||||
|| !line.starts_with("data: ")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let json_str = &line[6..];
|
||||
self.sse_lines_parsed += 1;
|
||||
|
||||
match serde_json::from_str(json_str) {
|
||||
Ok(v) => return Ok(Some(v)),
|
||||
Err(e) => {
|
||||
self.sse_parse_errors += 1;
|
||||
if self.sse_parse_errors == 1 || self.debug {
|
||||
let preview = if json_str.len() > 200 {
|
||||
format!("{}...", &json_str[..200])
|
||||
} else {
|
||||
json_str.to_string()
|
||||
};
|
||||
dbglog!(
|
||||
"SSE parse error (#{}) {}: {}",
|
||||
self.sse_parse_errors, e, preview
|
||||
);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if self.done {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Read more data from the response stream
|
||||
match tokio::time::timeout(self.chunk_timeout, response.chunk()).await {
|
||||
Ok(Ok(Some(chunk))) => {
|
||||
self.chunks_received += 1;
|
||||
self.line_buf.push_str(&String::from_utf8_lossy(&chunk));
|
||||
}
|
||||
Ok(Ok(None)) => return Ok(None),
|
||||
Ok(Err(e)) => {
|
||||
let buf_preview = if self.line_buf.is_empty() {
|
||||
"(empty)".to_string()
|
||||
} else {
|
||||
let n = self.line_buf.len().min(500);
|
||||
format!("{}B: {}", self.line_buf.len(), &self.line_buf[..n])
|
||||
};
|
||||
let msg = format!(
|
||||
"stream error after {} chunks, {:.1}s, {} sse lines: {} | buf: {}",
|
||||
self.chunks_received,
|
||||
self.stream_start.elapsed().as_secs_f64(),
|
||||
self.sse_lines_parsed,
|
||||
e, buf_preview,
|
||||
);
|
||||
dbglog!("{}", msg);
|
||||
self.save_failed_request(&msg);
|
||||
return Err(e.into());
|
||||
}
|
||||
Err(_) => {
|
||||
let buf_preview = if self.line_buf.is_empty() {
|
||||
"(empty)".to_string()
|
||||
} else {
|
||||
let n = self.line_buf.len().min(500);
|
||||
format!("{}B: {}", self.line_buf.len(), &self.line_buf[..n])
|
||||
};
|
||||
let msg = format!(
|
||||
"stream timeout: {}s, {} chunks, {} sse lines, {:.1}s elapsed | buf: {}",
|
||||
self.chunk_timeout.as_secs(),
|
||||
self.chunks_received,
|
||||
self.sse_lines_parsed,
|
||||
self.stream_start.elapsed().as_secs_f64(),
|
||||
buf_preview,
|
||||
);
|
||||
dbglog!("{}", msg);
|
||||
self.save_failed_request(&msg);
|
||||
anyhow::bail!(
|
||||
"stream timeout: no data for {}s ({} chunks received)",
|
||||
self.chunk_timeout.as_secs(),
|
||||
self.chunks_received
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
249
src/agent/api/salience.rs
Normal file
249
src/agent/api/salience.rs
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
// agent/api/salience.rs — gRPC client bindings for salience.v1.
|
||||
//
|
||||
// Thin wrapper around the tonic-generated types. Every RPC except
|
||||
// Generate is unary; Generate is server-streaming. Free functions
|
||||
// (open/close session) wrap the lifecycle RPCs; `SessionHandle` just
|
||||
// carries the id + connection params so later RPCs can reuse them.
|
||||
//
|
||||
// The old bidi Session() API is gone — see git history for its shape.
|
||||
|
||||
#![allow(clippy::enum_variant_names)]
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint};
|
||||
|
||||
/// Generated prost + tonic types for salience.v1. Call sites use
|
||||
/// `pb::OpenSessionRequest`, `pb::Token`, etc.
|
||||
pub mod pb {
|
||||
tonic::include_proto!("salience.v1");
|
||||
}
|
||||
|
||||
pub type SalienceClient = pb::salience_client::SalienceClient<Channel>;
|
||||
|
||||
/// Open a TLS-aware gRPC channel to the salience server. `base_url`
|
||||
/// looks like `https://host:8443`. User-provided CA certs under
|
||||
/// `~/.consciousness/certs/` are trusted in addition to the system
|
||||
/// roots (for self-signed server certs).
|
||||
pub async fn connect(base_url: &str) -> Result<SalienceClient> {
|
||||
let mut endpoint = Endpoint::from_shared(base_url.to_string())
|
||||
.with_context(|| format!("invalid salience endpoint: {}", base_url))?
|
||||
.connect_timeout(std::time::Duration::from_secs(30))
|
||||
.timeout(std::time::Duration::from_secs(600));
|
||||
|
||||
if base_url.starts_with("https://") {
|
||||
let user_certs = super::http::load_user_certs_pem_bytes();
|
||||
let mut tls = ClientTlsConfig::new().with_native_roots();
|
||||
if !user_certs.is_empty() {
|
||||
tls = tls.ca_certificate(Certificate::from_pem(user_certs));
|
||||
}
|
||||
endpoint = endpoint
|
||||
.tls_config(tls)
|
||||
.with_context(|| "configuring tonic TLS")?;
|
||||
}
|
||||
|
||||
let channel = endpoint
|
||||
.connect()
|
||||
.await
|
||||
.with_context(|| format!("failed to connect to salience server at {}", base_url))?;
|
||||
Ok(pb::salience_client::SalienceClient::new(channel))
|
||||
}
|
||||
|
||||
/// Derive the gRPC base URL from the HTTP completions base URL.
///
/// vLLM's salience gRPC server listens on a different port (8443) from
/// the HTTP endpoint (8000) and accepts no path component. Given an
/// HTTP base like `https://host:8000/v1`, produce `https://host:8443`.
///
/// Only a port at the very end of the authority is rewritten, so a
/// `:8000` inside userinfo or an IPv6 literal (`https://[::8000]:8000`)
/// is left alone. A URL without a path whose port is not 8000 is returned
/// unchanged apart from trailing slashes. Input without a `scheme://`
/// prefix keeps its path, and only the leading `host:8000` is rewritten.
pub fn derive_grpc_url(http_base: &str) -> String {
    const HTTP_PORT: &str = ":8000";
    const GRPC_PORT: &str = ":8443";

    let trimmed = http_base.trim_end_matches('/');
    let (authority_start, has_scheme) = match trimmed.find("://") {
        Some(scheme_end) => (scheme_end + 3, true),
        None => (0, false),
    };
    // The authority runs up to the first `/` after the scheme.
    let authority_end = trimmed[authority_start..]
        .find('/')
        .map_or(trimmed.len(), |slash| authority_start + slash);
    let (origin, path) = trimmed.split_at(authority_end);
    // The gRPC endpoint takes no path; scheme-less input is passed through.
    let tail = if has_scheme { "" } else { path };

    match origin.strip_suffix(HTTP_PORT) {
        Some(host) => format!("{host}{GRPC_PORT}{tail}"),
        None => format!("{origin}{tail}"),
    }
}
|
||||
|
||||
/// Attach a bearer token to a tonic request as gRPC metadata.
|
||||
pub fn with_auth<T>(req: &mut tonic::Request<T>, api_key: &str) {
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
let bearer = format!("Bearer {}", api_key);
|
||||
if let Ok(val) = bearer.parse() {
|
||||
req.metadata_mut().insert("authorization", val);
|
||||
}
|
||||
}
|
||||
|
||||
/// Call the server's `OpenSession` RPC and return the response.
|
||||
pub async fn open_session(
|
||||
base_url: &str,
|
||||
api_key: &str,
|
||||
model: &str,
|
||||
) -> Result<pb::OpenSessionResponse> {
|
||||
let mut client = connect(base_url).await?;
|
||||
let mut req = tonic::Request::new(pb::OpenSessionRequest {
|
||||
model: model.to_string(),
|
||||
});
|
||||
with_auth(&mut req, api_key);
|
||||
let resp = client
|
||||
.open_session(req)
|
||||
.await
|
||||
.with_context(|| "OpenSession RPC failed")?;
|
||||
Ok(resp.into_inner())
|
||||
}
|
||||
|
||||
/// Call the server's `CloseSession` RPC. Idempotent on the server.
|
||||
pub async fn close_session(base_url: &str, api_key: &str, session_id: &str) -> Result<()> {
|
||||
let mut client = connect(base_url).await?;
|
||||
let mut req = tonic::Request::new(pb::CloseSessionRequest {
|
||||
session_id: session_id.to_string(),
|
||||
});
|
||||
with_auth(&mut req, api_key);
|
||||
client
|
||||
.close_session(req)
|
||||
.await
|
||||
.with_context(|| "CloseSession RPC failed")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Append an image to a session. Server decodes the image, computes N
|
||||
/// via vLLM's own multimodal pipeline, writes the full vision block
|
||||
/// (`<|vision_start|> + IMAGE_PAD×N + <|vision_end|>`) into
|
||||
/// session.tokens, and returns (N, new total length).
|
||||
///
|
||||
/// `offset` is the client's view of the session's current token count;
|
||||
/// the server rejects if it diverges from its own (unless
|
||||
/// `truncating=true`, in which case the server slices to `offset`
|
||||
/// first — but never through a vision block).
|
||||
pub async fn append_image(
|
||||
base_url: &str,
|
||||
api_key: &str,
|
||||
session_id: &str,
|
||||
data: Vec<u8>,
|
||||
mime: String,
|
||||
offset: u32,
|
||||
truncating: bool,
|
||||
) -> Result<pb::AppendImageResponse> {
|
||||
let mut client = connect(base_url).await?;
|
||||
let mut req = tonic::Request::new(pb::AppendImageRequest {
|
||||
session_id: session_id.to_string(),
|
||||
data,
|
||||
mime,
|
||||
offset,
|
||||
truncating,
|
||||
});
|
||||
with_auth(&mut req, api_key);
|
||||
let resp = client
|
||||
.append_image(req)
|
||||
.await
|
||||
.with_context(|| "AppendImage RPC failed")?;
|
||||
Ok(resp.into_inner())
|
||||
}
|
||||
|
||||
/// Handle to a server-side session. Carries the id + connection params
|
||||
/// so subsequent per-session RPCs (AppendImage, Generate, ForkSession)
|
||||
/// can be issued without the caller juggling base_url / api_key each
|
||||
/// time.
|
||||
pub struct SessionHandle {
|
||||
pub session_id: String,
|
||||
pub max_model_len: u32,
|
||||
pub base_url: String,
|
||||
pub api_key: String,
|
||||
}
|
||||
|
||||
impl SessionHandle {
|
||||
pub async fn open(base_url: &str, api_key: &str, model: &str) -> Result<Self> {
|
||||
let grpc_url = derive_grpc_url(base_url);
|
||||
log::debug!(target: "grpc",
|
||||
"SessionHandle::open http_base={} -> grpc_url={}",
|
||||
base_url, grpc_url);
|
||||
let resp = open_session(&grpc_url, api_key, model).await?;
|
||||
log::debug!(target: "grpc",
|
||||
"SessionHandle::open session_id={} max_model_len={}",
|
||||
resp.session_id, resp.max_model_len);
|
||||
Ok(Self {
|
||||
session_id: resp.session_id,
|
||||
max_model_len: resp.max_model_len,
|
||||
base_url: grpc_url,
|
||||
api_key: api_key.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn close(self) -> Result<()> {
|
||||
close_session(&self.base_url, &self.api_key, &self.session_id).await
|
||||
}
|
||||
|
||||
/// Append an image via the server-side vision block. See
|
||||
/// `append_image` free function for full semantics.
|
||||
pub async fn append_image(
|
||||
&self,
|
||||
data: Vec<u8>,
|
||||
mime: String,
|
||||
offset: u32,
|
||||
truncating: bool,
|
||||
) -> Result<pb::AppendImageResponse> {
|
||||
append_image(
|
||||
&self.base_url,
|
||||
&self.api_key,
|
||||
&self.session_id,
|
||||
data,
|
||||
mime,
|
||||
offset,
|
||||
truncating,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time guard: constructs the generated proto types by field
    /// name, so a drift between build.rs output and the proto breaks the
    /// build here.
    #[test]
    fn generated_types_compile() {
        let _open = pb::OpenSessionRequest {
            model: "qwen3-vl".into(),
        };

        let alternative = pb::TokenLogprob {
            id: 1,
            logprob: -0.5,
        };
        let _tok = pb::Token {
            id: 42,
            position: 0,
            is_prefill: false,
            readout: vec![0.1, 0.2, 0.3],
            logprobs: vec![alternative],
            sampled_logprob: -0.1,
            has_sampled_logprob: true,
        };

        let done = pb::GenerateDone {
            prompt_tokens: 10,
            completion_tokens: 20,
            total_tokens: 30,
            finish_reason: pb::generate_done::FinishReason::Eos as i32,
        };
        let _evt = pb::GenerateEvent {
            event: Some(pb::generate_event::Event::Done(done)),
        };
    }

    /// Port 8000 maps to 8443 and the path is dropped; other ports are kept.
    #[test]
    fn derive_grpc_url_cases() {
        let cases = [
            ("https://host:8000/v1", "https://host:8443"),
            ("https://host:8000/", "https://host:8443"),
            ("https://host:9000/v1", "https://host:9000"),
        ];
        for (http_base, expected) in cases {
            assert_eq!(derive_grpc_url(http_base), expected);
        }
    }
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue