agent/api: carry readout alongside streamed tokens
`StreamToken::Token` is now a struct variant that carries an optional `TokenReadout` (shape `[n_layers][n_concepts]`) for each token, parsed from the `choices[i].readout` field of the vLLM completion response when the server has readout enabled. `ApiClient` gains a `fetch_readout_manifest()` method that issues `GET /v1/readout/manifest`; it returns `Ok(None)` on 404 (readout is disabled on the server), so callers can fall back gracefully when pointed at an endpoint without readout support. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
047da10123
commit
0f1c4cf1de
3 changed files with 79 additions and 7 deletions
|
|
@ -22,6 +22,21 @@ pub struct Usage {
|
||||||
pub total_tokens: u32,
|
pub total_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Concept-readout manifest returned by the vLLM server's
|
||||||
|
/// `/v1/readout/manifest` endpoint. Maps the nameless tensor indices
|
||||||
|
/// in streaming `readout` fields back to concept names and layer
|
||||||
|
/// indices.
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct ReadoutManifest {
|
||||||
|
pub concepts: Vec<String>,
|
||||||
|
pub layers: Vec<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Per-token per-layer concept projections streamed alongside each
|
||||||
|
/// sampled token. Shape `[n_layers][n_concepts]`. Named values come
|
||||||
|
/// from pairing with the manifest fetched at startup.
|
||||||
|
pub type TokenReadout = Vec<Vec<f32>>;
|
||||||
|
|
||||||
/// A JoinHandle that aborts its task when dropped.
|
/// A JoinHandle that aborts its task when dropped.
|
||||||
pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>);
|
pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>);
|
||||||
|
|
||||||
|
|
@ -45,7 +60,10 @@ pub(crate) struct SamplingParams {
|
||||||
|
|
||||||
/// One token from the streaming completions API.
|
/// One token from the streaming completions API.
|
||||||
pub enum StreamToken {
|
pub enum StreamToken {
|
||||||
Token(u32),
|
/// A sampled token, optionally with its per-layer concept readout.
|
||||||
|
/// `readout` is `None` when the server has readout disabled or
|
||||||
|
/// returned no readout for this chunk.
|
||||||
|
Token { id: u32, readout: Option<TokenReadout> },
|
||||||
Done { usage: Option<Usage> },
|
Done { usage: Option<Usage> },
|
||||||
Error(String),
|
Error(String),
|
||||||
}
|
}
|
||||||
|
|
@ -106,6 +124,32 @@ impl ApiClient {
|
||||||
pub fn base_url(&self) -> &str { &self.base_url }
|
pub fn base_url(&self) -> &str { &self.base_url }
|
||||||
pub fn api_key(&self) -> &str { &self.api_key }
|
pub fn api_key(&self) -> &str { &self.api_key }
|
||||||
|
|
||||||
|
/// Fetch `/v1/readout/manifest` — returns `Ok(Some(..))` if
|
||||||
|
/// readout is enabled on the server, `Ok(None)` on 404 (disabled),
|
||||||
|
/// or an error on any other failure.
|
||||||
|
///
|
||||||
|
/// Call once at startup and cache the result; the manifest doesn't
|
||||||
|
/// change during a server run.
|
||||||
|
pub async fn fetch_readout_manifest(&self) -> Result<Option<ReadoutManifest>> {
|
||||||
|
let url = format!("{}/readout/manifest", self.base_url);
|
||||||
|
let auth = format!("Bearer {}", self.api_key);
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.get_with_headers(&url, &[("Authorization", &auth)])
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("readout manifest fetch ({}): {}", url, e))?;
|
||||||
|
let status = response.status();
|
||||||
|
if status.as_u16() == 404 {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
if !status.is_success() {
|
||||||
|
let body = response.text().await.unwrap_or_default();
|
||||||
|
let n = body.floor_char_boundary(body.len().min(500));
|
||||||
|
anyhow::bail!("readout manifest HTTP {} ({}): {}", status, url, &body[..n]);
|
||||||
|
}
|
||||||
|
Ok(Some(response.json().await?))
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn stream_completions(
|
async fn stream_completions(
|
||||||
|
|
@ -172,17 +216,45 @@ async fn stream_completions(
|
||||||
};
|
};
|
||||||
|
|
||||||
for choice in choices {
|
for choice in choices {
|
||||||
|
// `readout`, if present, is a nested list
|
||||||
|
// `[num_tokens][n_layers][n_concepts]`. Parse it once per
|
||||||
|
// chunk and pair rows with token ids by index — the rows
|
||||||
|
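// are in the same order as `token_ids`. NOTE(review): the `filter_map` calls below drop any malformed row/layer/value, which would shift later entries onto the wrong index — confirm the server never emits null or non-array entries.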
|
||||||
|
let readouts: Option<Vec<TokenReadout>> = choice["readout"]
|
||||||
|
.as_array()
|
||||||
|
.map(|outer| {
|
||||||
|
outer.iter().filter_map(|per_token| {
|
||||||
|
per_token.as_array().map(|layers| {
|
||||||
|
layers.iter().filter_map(|per_layer| {
|
||||||
|
per_layer.as_array().map(|vals| {
|
||||||
|
vals.iter()
|
||||||
|
.filter_map(|v| v.as_f64().map(|f| f as f32))
|
||||||
|
.collect::<Vec<f32>>()
|
||||||
|
})
|
||||||
|
}).collect::<Vec<Vec<f32>>>()
|
||||||
|
})
|
||||||
|
}).collect()
|
||||||
|
});
|
||||||
|
|
||||||
if let Some(ids) = choice["token_ids"].as_array() {
|
if let Some(ids) = choice["token_ids"].as_array() {
|
||||||
for id_val in ids {
|
for (i, id_val) in ids.iter().enumerate() {
|
||||||
if let Some(id) = id_val.as_u64() {
|
if let Some(id) = id_val.as_u64() {
|
||||||
let _ = tx.send(StreamToken::Token(id as u32));
|
let readout = readouts
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|r| r.get(i).cloned());
|
||||||
|
let _ = tx.send(StreamToken::Token {
|
||||||
|
id: id as u32,
|
||||||
|
readout,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if let Some(text) = choice["text"].as_str() {
|
} else if let Some(text) = choice["text"].as_str() {
|
||||||
// Fallback: provider didn't return token_ids, encode locally
|
// Fallback: provider didn't return token_ids, encode locally.
|
||||||
|
// No readout available in this path — the encoder may
|
||||||
|
// produce a different token count than the server did.
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
for id in super::tokenizer::encode(text) {
|
for id in super::tokenizer::encode(text) {
|
||||||
let _ = tx.send(StreamToken::Token(id));
|
let _ = tx.send(StreamToken::Token { id, readout: None });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -682,7 +682,7 @@ impl ResponseParser {
|
||||||
let mut full_text = String::new();
|
let mut full_text = String::new();
|
||||||
while let Some(event) = stream.recv().await {
|
while let Some(event) = stream.recv().await {
|
||||||
match event {
|
match event {
|
||||||
super::api::StreamToken::Token(id) => {
|
super::api::StreamToken::Token { id, readout: _ } => {
|
||||||
let text = super::tokenizer::decode(&[id]);
|
let text = super::tokenizer::decode(&[id]);
|
||||||
full_text.push_str(&text);
|
full_text.push_str(&text);
|
||||||
let mut ctx = agent.context.lock().await;
|
let mut ctx = agent.context.lock().await;
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ where F: FnMut(&AstNode) -> bool,
|
||||||
let mut tokens = Vec::new();
|
let mut tokens = Vec::new();
|
||||||
while let Some(tok) = rx.recv().await {
|
while let Some(tok) = rx.recv().await {
|
||||||
match tok {
|
match tok {
|
||||||
StreamToken::Token(id) => tokens.push(id),
|
StreamToken::Token { id, .. } => tokens.push(id),
|
||||||
StreamToken::Done { .. } => break,
|
StreamToken::Done { .. } => break,
|
||||||
StreamToken::Error(e) => anyhow::bail!("generation error: {}", e),
|
StreamToken::Error(e) => anyhow::bail!("generation error: {}", e),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue