consciousness/src/agent/api/mod.rs

514 lines
18 KiB
Rust
Raw Normal View History

// api/ — LLM API client (OpenAI-compatible)
//
// Works with any provider that implements the OpenAI chat completions
// API: OpenRouter, vLLM, llama.cpp, Fireworks, Together, etc.
//
// Diagnostics: anomalies always logged to debug panel.
// Set POC_DEBUG=1 for verbose per-turn logging.
pub mod http;
use std::time::{Duration, Instant};
use anyhow::Result;
use tokio::sync::mpsc;
use serde::Deserialize;
use http::{HttpClient, HttpResponse};
/// Token accounting for one completion request.
///
/// Deserialized from the `usage` object of a streamed event and handed to
/// the consumer inside `StreamToken::Done`. There are no serde renames, so
/// each field is read from the JSON key of the same name.
#[derive(Debug, Clone, Deserialize)]
pub struct Usage {
/// `usage.prompt_tokens` as reported by the server.
pub prompt_tokens: u32,
/// `usage.completion_tokens` as reported by the server.
pub completion_tokens: u32,
/// `usage.total_tokens` as reported by the server (presumably prompt +
/// completion; this client does not verify the sum).
pub total_tokens: u32,
}
/// Concept-readout manifest returned by the vLLM server's
/// `/v1/readout/manifest` endpoint. Maps the nameless tensor indices
/// in streaming `readout` fields back to concept names and layer
/// indices.
#[derive(Debug, Clone, Deserialize)]
pub struct ReadoutManifest {
/// Concept names. Presumably position `c` names column `c` of every
/// layer row in a `TokenReadout` — TODO confirm ordering with the server.
pub concepts: Vec<String>,
/// Layer indices. Presumably position `l` identifies row `l` of a
/// `TokenReadout` — TODO confirm ordering with the server.
pub layers: Vec<u32>,
}
/// Per-token per-layer concept projections streamed alongside each
/// sampled token. Shape `[n_layers][n_concepts]`. Named values come
/// from pairing with the manifest fetched at startup.
///
/// The outer index selects the layer, the inner index the concept. Values
/// arrive as JSON numbers and are narrowed from `f64` to `f32` when parsed.
pub type TokenReadout = Vec<Vec<f32>>;
/// A JoinHandle that aborts its task when dropped.
///
/// Returned next to the token receiver by `ApiClient::stream_completion_mm`;
/// the caller must keep it alive for as long as the stream should keep
/// running, because dropping it cancels the spawned task.
pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>);
impl Drop for AbortOnDrop {
// Request cancellation of the wrapped task; the handle is never awaited.
fn drop(&mut self) {
self.0.abort();
}
}
/// Sampling parameters for model generation.
///
/// Sent verbatim as the `temperature`, `top_p` and `top_k` fields of the
/// completion request. The struct is plain `Copy` data; `Debug` and
/// `PartialEq` are derived so values can be logged and compared.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) struct SamplingParams {
    /// Value of the request's `temperature` field.
    pub temperature: f32,
    /// Value of the request's `top_p` field.
    pub top_p: f32,
    /// Value of the request's `top_k` field.
    pub top_k: u32,
}
// ─────────────────────────────────────────────────────────────
// Stream events — yielded by backends, consumed by the runner
// ─────────────────────────────────────────────────────────────
/// One token from the streaming completions API.
///
/// A stream is a sequence of `Token` events terminated by exactly one of
/// `Done` (clean end) or `Error` (the streaming task failed).
pub enum StreamToken {
/// A sampled token, optionally with its per-layer concept readout.
/// `readout` is `None` when the server has readout disabled or
/// returned no readout for this chunk.
Token { id: u32, readout: Option<TokenReadout> },
/// The stream ended without error. `usage` holds the last `usage`
/// object seen in the stream that deserialized into `Usage`, if any.
Done { usage: Option<Usage> },
/// The streaming task failed; the payload is the error's display text.
/// No `Done` follows an `Error`.
Error(String),
}
/// Handle for talking to one OpenAI-compatible completions server.
///
/// Cloning duplicates the HTTP client handle and the stored strings.
#[derive(Clone)]
pub struct ApiClient {
// HTTP client built in `new` with connect and overall request timeouts.
client: HttpClient,
// Sent as `Authorization: Bearer <api_key>` on every request.
api_key: String,
/// Model identifier sent as the `model` field of completion requests.
pub model: String,
// Stored without trailing `/`; endpoint paths such as `/completions`
// are appended directly.
base_url: String,
}
impl ApiClient {
pub fn new(base_url: &str, api_key: &str, model: &str) -> Self {
let client = HttpClient::builder()
.connect_timeout(Duration::from_secs(30))
.timeout(Duration::from_secs(600))
.build();
Self {
client,
api_key: api_key.to_string(),
model: model.to_string(),
base_url: base_url.trim_end_matches('/').to_string(),
}
}
pub(crate) fn stream_completion_mm(
&self,
prompt_tokens: &[u32],
images: &[super::context::WireImage],
sampling: SamplingParams,
priority: Option<i32>,
) -> (mpsc::UnboundedReceiver<StreamToken>, AbortOnDrop) {
let (tx, rx) = mpsc::unbounded_channel();
let client = self.client.clone();
let api_key = self.api_key.clone();
let model = self.model.clone();
let prompt_tokens = prompt_tokens.to_vec();
let images: Vec<(Vec<u8>, String)> = images.iter()
.map(|i| (i.bytes.clone(), i.mime.clone()))
.collect();
let base_url = self.base_url.clone();
let handle = tokio::spawn(async move {
let result = stream_completions(
&client, &base_url, &api_key, &model,
&prompt_tokens, &images, &tx, sampling, priority,
).await;
if let Err(e) = result {
let _ = tx.send(StreamToken::Error(e.to_string()));
}
});
(rx, AbortOnDrop(handle))
}
pub fn base_url(&self) -> &str { &self.base_url }
pub fn api_key(&self) -> &str { &self.api_key }
/// Fetch `/v1/readout/manifest` — returns `Ok(Some(..))` if
/// readout is enabled on the server, `Ok(None)` on 404 (disabled),
/// or an error on any other failure.
///
/// Call once at startup and cache the result; the manifest doesn't
/// change during a server run.
pub async fn fetch_readout_manifest(&self) -> Result<Option<ReadoutManifest>> {
let url = format!("{}/readout/manifest", self.base_url);
let auth = format!("Bearer {}", self.api_key);
let response = self
.client
.get_with_headers(&url, &[("Authorization", &auth)])
.await
.map_err(|e| anyhow::anyhow!("readout manifest fetch ({}): {}", url, e))?;
let status = response.status();
if status.as_u16() == 404 {
return Ok(None);
}
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
let n = body.floor_char_boundary(body.len().min(500));
anyhow::bail!("readout manifest HTTP {} ({}): {}", status, url, &body[..n]);
}
Ok(Some(response.json().await?))
}
}
/// Parse one token's readout (`[n_layers][n_concepts]`).
///
/// Returns `None` if the value is not an array of arrays of numbers.
/// Parsing is all-or-nothing: silently dropping a malformed layer or value
/// would shift every later index and mislabel the remaining data.
fn parse_token_readout(value: &serde_json::Value) -> Option<TokenReadout> {
    value
        .as_array()?
        .iter()
        .map(|layer| {
            layer
                .as_array()?
                .iter()
                .map(|v| v.as_f64().map(|f| f as f32))
                .collect::<Option<Vec<f32>>>()
        })
        .collect()
}

/// Parse a chunk's `readout` field (`[num_tokens][n_layers][n_concepts]`).
///
/// Returns `None` when the field is absent or not an array. Otherwise the
/// result has exactly one slot per input row, with `None` for rows that
/// failed to parse, so slot `i` always belongs to `token_ids[i]`.
fn parse_chunk_readouts(value: &serde_json::Value) -> Option<Vec<Option<TokenReadout>>> {
    let per_token = value.as_array()?;
    Some(per_token.iter().map(parse_token_readout).collect())
}

/// Run one streaming completion request and forward its tokens to `tx`.
///
/// Builds the request (token-id prompt, sampling parameters, optional
/// base64 `data:` image URIs, optional `priority`), POSTs it to
/// `<base_url>/completions`, and reads the SSE response. Each token id is
/// sent as `StreamToken::Token`, followed by one `StreamToken::Done`
/// carrying the last usage object seen. Send failures (receiver dropped)
/// are ignored.
///
/// # Errors
///
/// Returns an error if the request fails, the server answers with a
/// non-success status, the stream breaks or times out, or an event carries
/// an `error.message`. No `Done` is sent in that case; the caller reports
/// the error.
async fn stream_completions(
    client: &HttpClient,
    base_url: &str,
    api_key: &str,
    model: &str,
    prompt_tokens: &[u32],
    images: &[(Vec<u8>, String)],
    tx: &mpsc::UnboundedSender<StreamToken>,
    sampling: SamplingParams,
    priority: Option<i32>,
) -> anyhow::Result<()> {
    let mut request = serde_json::json!({
        "model": model,
        "prompt": prompt_tokens,
        "max_tokens": 16384,
        "temperature": sampling.temperature,
        "top_p": sampling.top_p,
        "top_k": sampling.top_k,
        "stream": true,
        "return_token_ids": true,
        "skip_special_tokens": false,
        "stop_token_ids": [super::tokenizer::IM_END],
    });
    if !images.is_empty() {
        use base64::Engine;
        let b64 = base64::engine::general_purpose::STANDARD;
        let uris: Vec<String> = images
            .iter()
            .map(|(bytes, mime)| format!("data:{};base64,{}", mime, b64.encode(bytes)))
            .collect();
        request["multi_modal_data"] = serde_json::json!({ "image": uris });
    }
    if let Some(p) = priority {
        request["priority"] = serde_json::json!(p);
    }

    let url = format!("{}/completions", base_url);
    let debug_label = format!("{} prompt tokens, model={}", prompt_tokens.len(), model);
    let mut response = send_and_check(
        client, &url, &request,
        ("Authorization", &format!("Bearer {}", api_key)),
        &[], &debug_label, None,
    ).await?;

    let mut reader = SseReader::new();
    let mut usage = None;
    while let Some(event) = reader.next_event(&mut response).await? {
        if let Some(err_msg) = event["error"]["message"].as_str() {
            anyhow::bail!("API error in stream: {}", err_msg);
        }
        // Keep the most recent usage object that deserializes cleanly.
        if let Some(u) = event["usage"].as_object() {
            if let Ok(u) = serde_json::from_value::<Usage>(serde_json::Value::Object(u.clone())) {
                usage = Some(u);
            }
        }
        let Some(choices) = event["choices"].as_array() else {
            continue;
        };
        for choice in choices {
            // One slot per readout row, in the same order as `token_ids`.
            let mut readouts = parse_chunk_readouts(&choice["readout"]);

            if let Some(ids) = choice["token_ids"].as_array() {
                for (i, id_val) in ids.iter().enumerate() {
                    // Skip entries that are not integers or do not fit in u32;
                    // a truncating cast would forward a different token.
                    let Some(id) = id_val.as_u64().and_then(|raw| u32::try_from(raw).ok()) else {
                        continue;
                    };
                    // Each row is consumed once, so move it out, not clone it.
                    let readout = readouts
                        .as_mut()
                        .and_then(|rows| rows.get_mut(i))
                        .and_then(|slot| slot.take());
                    let _ = tx.send(StreamToken::Token { id, readout });
                }
            } else if let Some(text) = choice["text"].as_str() {
                // Fallback: provider didn't return token_ids, encode locally.
                // No readout available in this path — the encoder may
                // produce a different token count than the server did.
                if !text.is_empty() {
                    for id in super::tokenizer::encode(text) {
                        let _ = tx.send(StreamToken::Token { id, readout: None });
                    }
                }
            }
        }
    }
    let _ = tx.send(StreamToken::Done { usage });
    Ok(())
}
/// POST `body` as JSON to `url` and return the response if its status is a
/// success.
///
/// `auth_header` is sent first, followed by `extra_headers`. Setting the
/// `POC_DEBUG` environment variable (any value) turns on verbose logging of
/// payload size, selected response headers and connect time; `debug_label`
/// is included in the request log line.
///
/// On a non-success status the (truncated) response body is logged, and
/// `request_json`, when provided, is written to
/// `~/.consciousness/logs/failed-requests/<timestamp>.json` for replay.
///
/// # Errors
///
/// Returns an error for transport failures (prefixed with a coarse cause
/// derived from the error text) and for any non-success HTTP status.
pub(crate) async fn send_and_check(
    client: &HttpClient,
    url: &str,
    body: &impl serde::Serialize,
    auth_header: (&str, &str),
    extra_headers: &[(&str, &str)],
    debug_label: &str,
    request_json: Option<&str>,
) -> Result<HttpResponse> {
    let verbose = std::env::var("POC_DEBUG").is_ok();
    let started = Instant::now();

    if verbose {
        // Serialize only to measure; a serialization failure counts as 0 bytes.
        let payload_bytes = match serde_json::to_string(body) {
            Ok(serialized) => serialized.len(),
            Err(_) => 0,
        };
        dbglog!(
            "request: {}K payload, {}",
            payload_bytes / 1024, debug_label,
        );
    }

    // Auth header first, then the caller's extras.
    let mut headers: Vec<(&str, &str)> = Vec::with_capacity(1 + extra_headers.len());
    headers.push(auth_header);
    headers.extend_from_slice(extra_headers);

    let response = match client.send_json("POST", url, &headers, body).await {
        Ok(resp) => resp,
        Err(e) => {
            let msg = e.to_string();
            let cause = match () {
                _ if msg.contains("connect timeout") || msg.contains("TCP connect") => {
                    "connection refused"
                }
                _ if msg.contains("request timeout") => "request timed out",
                _ => "request error",
            };
            return Err(anyhow::anyhow!("{} ({}): {}", cause, url, msg));
        }
    };

    let status = response.status();
    let elapsed = started.elapsed();

    if verbose {
        const LOGGED_HEADERS: [&str; 3] = [
            "x-ratelimit-remaining",
            "x-ratelimit-limit",
            "x-request-id",
        ];
        for name in LOGGED_HEADERS {
            if let Some(val) = response.header(name) {
                dbglog!("header {}: {}", name, val);
            }
        }
    }

    if status.is_success() {
        if verbose {
            dbglog!(
                "connected in {:.1}s (HTTP {})",
                elapsed.as_secs_f64(),
                status.as_u16()
            );
        }
        return Ok(response);
    }

    // Failure path: log a short body excerpt, optionally persist the request,
    // then report a longer excerpt in the returned error.
    let body = response.text().await.unwrap_or_default();
    let log_cut = body.floor_char_boundary(body.len().min(500));
    dbglog!(
        "HTTP {} after {:.1}s ({}): {}",
        status,
        elapsed.as_secs_f64(),
        url,
        &body[..log_cut]
    );

    if let Some(json) = request_json {
        let log_dir = dirs::home_dir()
            .unwrap_or_default()
            .join(".consciousness/logs/failed-requests");
        // Best effort: a failure here simply makes the write below fail.
        let _ = std::fs::create_dir_all(&log_dir);
        let ts = chrono::Local::now().format("%Y%m%dT%H%M%S");
        let path = log_dir.join(format!("{}.json", ts));
        if std::fs::write(&path, json).is_ok() {
            dbglog!(
                "saved failed request to {} (HTTP {})", path.display(), status
            );
        }
    }

    let err_cut = body.floor_char_boundary(body.len().min(1000));
    anyhow::bail!("HTTP {} ({}): {}", status, url, &body[..err_cut]);
}
/// SSE stream reader. Handles the generic SSE plumbing shared by both
/// backends: chunk reading with timeout, line buffering, `data:` prefix
/// stripping, `[DONE]` detection, JSON parsing, and parse error diagnostics.
/// Yields parsed events as serde_json::Value — each backend handles its
/// own event types.
pub(crate) struct SseReader {
// Received text that has not yet been split into complete lines.
line_buf: String,
// Maximum wait for each chunk, from config `api_stream_timeout_secs`.
chunk_timeout: Duration,
/// When this reader was constructed; used for elapsed-time diagnostics.
pub stream_start: Instant,
/// Number of body chunks read so far.
pub chunks_received: u64,
/// Number of `data: ` lines handed to the JSON parser (including ones
/// that failed to parse).
pub sse_lines_parsed: u64,
/// Number of `data: ` lines whose payload was not valid JSON.
pub sse_parse_errors: u64,
// True when the `POC_DEBUG` environment variable was set at construction.
debug: bool,
// Set once `data: [DONE]` has been seen; no further chunks are read.
done: bool,
/// Serialized request payload — saved to disk on errors for replay debugging.
/// `new()` leaves this `None`; callers that want the dump must set it.
pub(crate) request_json: Option<String>,
}
impl SseReader {
    /// Create a reader with empty buffers and zeroed counters.
    ///
    /// The per-chunk timeout comes from the global config
    /// (`api_stream_timeout_secs`); verbose diagnostics are enabled when the
    /// `POC_DEBUG` environment variable is set.
    pub(crate) fn new() -> Self {
        Self {
            line_buf: String::new(),
            chunk_timeout: Duration::from_secs(crate::config::get().api_stream_timeout_secs),
            stream_start: Instant::now(),
            chunks_received: 0,
            sse_lines_parsed: 0,
            sse_parse_errors: 0,
            debug: std::env::var("POC_DEBUG").is_ok(),
            done: false,
            request_json: None,
        }
    }

    /// Save the request payload to disk for replay debugging.
    ///
    /// Does nothing when `request_json` is unset. The file goes to
    /// `~/.consciousness/logs/failed-requests/<timestamp>.json`; I/O failures
    /// are ignored because this is best-effort diagnostics.
    fn save_failed_request(&self, reason: &str) {
        let Some(ref json) = self.request_json else { return };
        let log_dir = dirs::home_dir()
            .unwrap_or_default()
            .join(".consciousness/logs/failed-requests");
        let _ = std::fs::create_dir_all(&log_dir);
        let ts = chrono::Local::now().format("%Y%m%dT%H%M%S");
        let path = log_dir.join(format!("{}.json", ts));
        if std::fs::write(&path, json).is_ok() {
            dbglog!(
                "saved failed request to {} ({})", path.display(), reason
            );
        }
    }

    /// Longest prefix of `s` that is at most `max_bytes` long and ends on a
    /// UTF-8 char boundary. Slicing at a raw byte offset would panic when
    /// the offset falls inside a multi-byte character.
    fn clip(s: &str, max_bytes: usize) -> &str {
        let mut end = s.len().min(max_bytes);
        // Offset 0 is always a boundary, so this terminates.
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        &s[..end]
    }

    /// Describe the unparsed buffer contents (size plus up to 500 bytes).
    fn buf_preview(&self) -> String {
        if self.line_buf.is_empty() {
            "(empty)".to_string()
        } else {
            format!("{}B: {}", self.line_buf.len(), Self::clip(&self.line_buf, 500))
        }
    }

    /// Read the next SSE event from the response stream.
    /// Returns Ok(Some(value)) for each parsed data line,
    /// Ok(None) when the stream ends or [DONE] is received.
    ///
    /// Lines that are empty or do not start with `data: ` are skipped.
    /// Payloads that are not valid JSON are counted and skipped; only the
    /// first failure is logged unless debug mode is on.
    ///
    /// # Errors
    ///
    /// Fails when reading a chunk errors or when no chunk arrives within
    /// the configured timeout. In both cases the diagnostics are logged and
    /// the request payload, if attached, is saved to disk.
    pub(crate) async fn next_event(
        &mut self,
        response: &mut HttpResponse,
    ) -> Result<Option<serde_json::Value>> {
        loop {
            // Drain complete lines from the buffer before reading more chunks
            while let Some(newline_pos) = self.line_buf.find('\n') {
                // Remove the line (with its '\n') in place; no buffer rebuild.
                let raw: String = self.line_buf.drain(..=newline_pos).collect();
                let line = raw.trim();
                if line == "data: [DONE]" {
                    self.done = true;
                    return Ok(None);
                }
                // Blank lines, `event:` lines and anything else that is not a
                // `data: ` line carry nothing to yield.
                let Some(json_str) = line.strip_prefix("data: ") else {
                    continue;
                };
                self.sse_lines_parsed += 1;
                match serde_json::from_str::<serde_json::Value>(json_str) {
                    Ok(v) => return Ok(Some(v)),
                    Err(e) => {
                        self.sse_parse_errors += 1;
                        if self.sse_parse_errors == 1 || self.debug {
                            let head = Self::clip(json_str, 200);
                            let preview = if head.len() < json_str.len() {
                                format!("{}...", head)
                            } else {
                                head.to_string()
                            };
                            dbglog!(
                                "SSE parse error (#{}) {}: {}",
                                self.sse_parse_errors, e, preview
                            );
                        }
                    }
                }
            }
            if self.done {
                return Ok(None);
            }
            // Read more data from the response stream
            match tokio::time::timeout(self.chunk_timeout, response.chunk()).await {
                Ok(Ok(Some(chunk))) => {
                    self.chunks_received += 1;
                    // NOTE(review): decoding each chunk on its own replaces a
                    // multi-byte character split across two chunks with
                    // U+FFFD; carrying the incomplete tail over would need an
                    // extra byte buffer on the struct.
                    self.line_buf.push_str(&String::from_utf8_lossy(&chunk));
                }
                Ok(Ok(None)) => return Ok(None),
                Ok(Err(e)) => {
                    let msg = format!(
                        "stream error after {} chunks, {:.1}s, {} sse lines: {} | buf: {}",
                        self.chunks_received,
                        self.stream_start.elapsed().as_secs_f64(),
                        self.sse_lines_parsed,
                        e, self.buf_preview(),
                    );
                    dbglog!("{}", msg);
                    self.save_failed_request(&msg);
                    return Err(e.into());
                }
                Err(_) => {
                    let msg = format!(
                        "stream timeout: {}s, {} chunks, {} sse lines, {:.1}s elapsed | buf: {}",
                        self.chunk_timeout.as_secs(),
                        self.chunks_received,
                        self.sse_lines_parsed,
                        self.stream_start.elapsed().as_secs_f64(),
                        self.buf_preview(),
                    );
                    dbglog!("{}", msg);
                    self.save_failed_request(&msg);
                    anyhow::bail!(
                        "stream timeout: no data for {}s ({} chunks received)",
                        self.chunk_timeout.as_secs(),
                        self.chunks_received
                    );
                }
            }
        }
    }
}