diff --git a/src/agent/api/mod.rs b/src/agent/api/mod.rs index 8c03bd4..be5e58e 100644 --- a/src/agent/api/mod.rs +++ b/src/agent/api/mod.rs @@ -22,6 +22,21 @@ pub struct Usage { pub total_tokens: u32, } +/// Concept-readout manifest returned by the vLLM server's +/// `/v1/readout/manifest` endpoint. Maps the nameless tensor indices +/// in streaming `readout` fields back to concept names and layer +/// indices. +#[derive(Debug, Clone, Deserialize)] +pub struct ReadoutManifest { + pub concepts: Vec, + pub layers: Vec, +} + +/// Per-token per-layer concept projections streamed alongside each +/// sampled token. Shape `[n_layers][n_concepts]`. Named values come +/// from pairing with the manifest fetched at startup. +pub type TokenReadout = Vec>; + /// A JoinHandle that aborts its task when dropped. pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>); @@ -45,7 +60,10 @@ pub(crate) struct SamplingParams { /// One token from the streaming completions API. pub enum StreamToken { - Token(u32), + /// A sampled token, optionally with its per-layer concept readout. + /// `readout` is `None` when the server has readout disabled or + /// returned no readout for this chunk. + Token { id: u32, readout: Option }, Done { usage: Option }, Error(String), } @@ -106,6 +124,32 @@ impl ApiClient { pub fn base_url(&self) -> &str { &self.base_url } pub fn api_key(&self) -> &str { &self.api_key } + /// Fetch `/v1/readout/manifest` — returns `Ok(Some(..))` if + /// readout is enabled on the server, `Ok(None)` on 404 (disabled), + /// or an error on any other failure. + /// + /// Call once at startup and cache the result; the manifest doesn't + /// change during a server run. For any other non-success status the error message carries at most the first 500 bytes of the response body, cut on a char boundary. NOTE(review): the request URL is built as `{base_url}/readout/manifest`, so `base_url` presumably already ends in `/v1` — confirm against how `base_url` is configured. 
+ pub async fn fetch_readout_manifest(&self) -> Result> { + let url = format!("{}/readout/manifest", self.base_url); + let auth = format!("Bearer {}", self.api_key); + let response = self + .client + .get_with_headers(&url, &[("Authorization", &auth)]) + .await + .map_err(|e| anyhow::anyhow!("readout manifest fetch ({}): {}", url, e))?; + let status = response.status(); + if status.as_u16() == 404 { + return Ok(None); + } + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + let n = body.floor_char_boundary(body.len().min(500)); + anyhow::bail!("readout manifest HTTP {} ({}): {}", status, url, &body[..n]); + } + Ok(Some(response.json().await?)) + } + } async fn stream_completions( @@ -172,17 +216,45 @@ async fn stream_completions( }; for choice in choices { + // `readout`, if present, is a nested list + // `[num_tokens][n_layers][n_concepts]`. Parse it once per + // chunk and pair rows with token ids by index — the rows + // are in the same order as `token_ids`. NOTE(review): each `filter_map` below silently drops entries that are not arrays (or not numbers), so a single malformed per-token row shifts every later row against `token_ids`, and a dropped value shifts concept indices within its layer — confirm the server never emits such entries, or keep placeholders so indices stay aligned. + let readouts: Option> = choice["readout"] + .as_array() + .map(|outer| { + outer.iter().filter_map(|per_token| { + per_token.as_array().map(|layers| { + layers.iter().filter_map(|per_layer| { + per_layer.as_array().map(|vals| { + vals.iter() + .filter_map(|v| v.as_f64().map(|f| f as f32)) + .collect::>() + }) + }).collect::>>() + }) + }).collect() + }); + if let Some(ids) = choice["token_ids"].as_array() { - for id_val in ids { + for (i, id_val) in ids.iter().enumerate() { if let Some(id) = id_val.as_u64() { - let _ = tx.send(StreamToken::Token(id as u32)); + let readout = readouts + .as_ref() + .and_then(|r| r.get(i).cloned()); + let _ = tx.send(StreamToken::Token { + id: id as u32, + readout, + }); } } } else if let Some(text) = choice["text"].as_str() { - // Fallback: provider didn't return token_ids, encode locally + // Fallback: provider didn't return token_ids, encode locally. 
+ // No readout available in this path — the encoder may + // produce a different token count than the server did. if !text.is_empty() { for id in super::tokenizer::encode(text) { - let _ = tx.send(StreamToken::Token(id)); + let _ = tx.send(StreamToken::Token { id, readout: None }); } } } diff --git a/src/agent/context.rs b/src/agent/context.rs index 948e9f2..49b9998 100644 --- a/src/agent/context.rs +++ b/src/agent/context.rs @@ -682,7 +682,7 @@ impl ResponseParser { let mut full_text = String::new(); while let Some(event) = stream.recv().await { match event { - super::api::StreamToken::Token(id) => { + super::api::StreamToken::Token { id, readout: _ } => { let text = super::tokenizer::decode(&[id]); full_text.push_str(&text); let mut ctx = agent.context.lock().await; diff --git a/src/subconscious/generate.rs b/src/subconscious/generate.rs index 44f967a..8d75f1b 100644 --- a/src/subconscious/generate.rs +++ b/src/subconscious/generate.rs @@ -36,7 +36,7 @@ where F: FnMut(&AstNode) -> bool, let mut tokens = Vec::new(); while let Some(tok) = rx.recv().await { match tok { - StreamToken::Token(id) => tokens.push(id), + StreamToken::Token { id, .. } => tokens.push(id), StreamToken::Done { .. } => break, StreamToken::Error(e) => anyhow::bail!("generation error: {}", e), }