agent/api: carry readout alongside streamed tokens

StreamToken::Token is now a struct variant with an optional
TokenReadout (shape [n_layers][n_concepts]) per token — parsed from
the vLLM completion response's choices[i].readout field when the
server has readout enabled.

ApiClient gains a fetch_readout_manifest() method that hits
GET /v1/readout/manifest. Returns Ok(None) on 404 (server has
readout disabled), so callers can gracefully fall back when pointed
at a non-readout-enabled endpoint.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-18 01:15:46 -04:00
parent 047da10123
commit 0f1c4cf1de
3 changed files with 79 additions and 7 deletions

View file

@ -22,6 +22,21 @@ pub struct Usage {
pub total_tokens: u32, pub total_tokens: u32,
} }
/// Concept-readout manifest returned by the vLLM server's
/// `/v1/readout/manifest` endpoint. Maps the nameless tensor indices
/// in streaming `readout` fields back to concept names and layer
/// indices.
#[derive(Debug, Clone, Deserialize)]
pub struct ReadoutManifest {
    /// Concept names: `concepts[k]` names column `k` of each
    /// per-layer row in a `TokenReadout`.
    pub concepts: Vec<String>,
    /// Layer indices: `layers[j]` names row `j` of each
    /// `TokenReadout` (shape `[n_layers][n_concepts]`).
    pub layers: Vec<u32>,
}
/// Per-token per-layer concept projections streamed alongside each
/// sampled token. Shape `[n_layers][n_concepts]`. Named values come
/// from pairing with the manifest fetched at startup: entry `[j][k]`
/// corresponds to `ReadoutManifest::layers[j]` /
/// `ReadoutManifest::concepts[k]`.
pub type TokenReadout = Vec<Vec<f32>>;
/// A JoinHandle that aborts its task when dropped.
// NOTE(review): presumably wraps the background streaming task so the
// stream is cancelled when the consumer goes away — confirm at the
// spawn site (the Drop impl is outside this view).
pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>);
@ -45,7 +60,10 @@ pub(crate) struct SamplingParams {
/// One token from the streaming completions API.
pub enum StreamToken {
    /// A sampled token, optionally with its per-layer concept readout.
    /// `readout` is `None` when the server has readout disabled or
    /// returned no readout for this chunk.
    Token { id: u32, readout: Option<TokenReadout> },
    /// End of stream; `usage` carries the provider's token accounting
    /// when it reported one.
    Done { usage: Option<Usage> },
    /// Stream-level failure, carrying the provider's error text.
    Error(String),
}
@ -106,6 +124,32 @@ impl ApiClient {
pub fn base_url(&self) -> &str { &self.base_url } pub fn base_url(&self) -> &str { &self.base_url }
pub fn api_key(&self) -> &str { &self.api_key } pub fn api_key(&self) -> &str { &self.api_key }
/// Fetch `/v1/readout/manifest` — returns `Ok(Some(..))` if
/// readout is enabled on the server, `Ok(None)` on 404 (disabled),
/// or an error on any other failure.
///
/// Call once at startup and cache the result; the manifest doesn't
/// change during a server run.
///
/// # Errors
/// Transport failures, non-404 HTTP error statuses, and JSON decode
/// failures are all surfaced as errors (with the URL for context).
pub async fn fetch_readout_manifest(&self) -> Result<Option<ReadoutManifest>> {
    let url = format!("{}/readout/manifest", self.base_url);
    let auth = format!("Bearer {}", self.api_key);
    let response = self
        .client
        .get_with_headers(&url, &[("Authorization", &auth)])
        .await
        .map_err(|e| anyhow::anyhow!("readout manifest fetch ({}): {}", url, e))?;
    let status = response.status();
    if status.as_u16() == 404 {
        // Server has readout disabled — a supported configuration,
        // not an error; callers fall back gracefully.
        return Ok(None);
    }
    if !status.is_success() {
        let body = response.text().await.unwrap_or_default();
        // Include at most 500 bytes of the body in the error, without
        // slicing through a multi-byte UTF-8 character (which would
        // panic). `str::floor_char_boundary` does exactly this but is
        // unstable (nightly-only), so walk back from the cap to the
        // nearest char boundary on stable.
        let mut n = body.len().min(500);
        while n > 0 && !body.is_char_boundary(n) {
            n -= 1;
        }
        anyhow::bail!("readout manifest HTTP {} ({}): {}", status, url, &body[..n]);
    }
    Ok(Some(response.json().await?))
}
} }
async fn stream_completions( async fn stream_completions(
@ -172,17 +216,45 @@ async fn stream_completions(
}; };
for choice in choices { for choice in choices {
// `readout`, if present, is a nested list
// `[num_tokens][n_layers][n_concepts]`. Parse it once per
// chunk and pair rows with token ids by index — the rows
// are in the same order as `token_ids`.
let readouts: Option<Vec<TokenReadout>> = choice["readout"]
.as_array()
.map(|outer| {
outer.iter().filter_map(|per_token| {
per_token.as_array().map(|layers| {
layers.iter().filter_map(|per_layer| {
per_layer.as_array().map(|vals| {
vals.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect::<Vec<f32>>()
})
}).collect::<Vec<Vec<f32>>>()
})
}).collect()
});
if let Some(ids) = choice["token_ids"].as_array() { if let Some(ids) = choice["token_ids"].as_array() {
for id_val in ids { for (i, id_val) in ids.iter().enumerate() {
if let Some(id) = id_val.as_u64() { if let Some(id) = id_val.as_u64() {
let _ = tx.send(StreamToken::Token(id as u32)); let readout = readouts
.as_ref()
.and_then(|r| r.get(i).cloned());
let _ = tx.send(StreamToken::Token {
id: id as u32,
readout,
});
} }
} }
} else if let Some(text) = choice["text"].as_str() { } else if let Some(text) = choice["text"].as_str() {
// Fallback: provider didn't return token_ids, encode locally // Fallback: provider didn't return token_ids, encode locally.
// No readout available in this path — the encoder may
// produce a different token count than the server did.
if !text.is_empty() { if !text.is_empty() {
for id in super::tokenizer::encode(text) { for id in super::tokenizer::encode(text) {
let _ = tx.send(StreamToken::Token(id)); let _ = tx.send(StreamToken::Token { id, readout: None });
} }
} }
} }

View file

@ -682,7 +682,7 @@ impl ResponseParser {
let mut full_text = String::new(); let mut full_text = String::new();
while let Some(event) = stream.recv().await { while let Some(event) = stream.recv().await {
match event { match event {
super::api::StreamToken::Token(id) => { super::api::StreamToken::Token { id, readout: _ } => {
let text = super::tokenizer::decode(&[id]); let text = super::tokenizer::decode(&[id]);
full_text.push_str(&text); full_text.push_str(&text);
let mut ctx = agent.context.lock().await; let mut ctx = agent.context.lock().await;

View file

@ -36,7 +36,7 @@ where F: FnMut(&AstNode) -> bool,
let mut tokens = Vec::new(); let mut tokens = Vec::new();
while let Some(tok) = rx.recv().await { while let Some(tok) = rx.recv().await {
match tok { match tok {
StreamToken::Token(id) => tokens.push(id), StreamToken::Token { id, .. } => tokens.push(id),
StreamToken::Done { .. } => break, StreamToken::Done { .. } => break,
StreamToken::Error(e) => anyhow::bail!("generation error: {}", e), StreamToken::Error(e) => anyhow::bail!("generation error: {}", e),
} }