diff --git a/src/agent/api/mod.rs b/src/agent/api/mod.rs index d336816..dc9f0fd 100644 --- a/src/agent/api/mod.rs +++ b/src/agent/api/mod.rs @@ -7,10 +7,6 @@ // Set POC_DEBUG=1 for verbose per-turn logging. pub mod http; -mod types; -mod openai; - -pub use types::Usage; use anyhow::Result; use std::time::{Duration, Instant}; @@ -18,6 +14,14 @@ use std::time::{Duration, Instant}; use self::http::{HttpClient, HttpResponse}; use tokio::sync::mpsc; +use serde::Deserialize; + +#[derive(Debug, Clone, Deserialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} /// A JoinHandle that aborts its task when dropped. pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>); @@ -70,8 +74,6 @@ impl ApiClient { } } - /// Stream a completion with raw token IDs. - /// Returns (text, token_id) per token via channel. pub(crate) fn stream_completion( &self, prompt_tokens: &[u32], @@ -86,7 +88,7 @@ impl ApiClient { let base_url = self.base_url.clone(); let handle = tokio::spawn(async move { - let result = openai::stream_completions( + let result = stream_completions( &client, &base_url, &api_key, &model, &prompt_tokens, &tx, sampling, priority, ).await; @@ -103,7 +105,84 @@ impl ApiClient { } -/// Send an HTTP request and check for errors. Shared by both backends. 
+async fn stream_completions( + client: &HttpClient, + base_url: &str, + api_key: &str, + model: &str, + prompt_tokens: &[u32], + tx: &mpsc::UnboundedSender<StreamToken>, + sampling: SamplingParams, + priority: Option<i32>, +) -> anyhow::Result<()> { + let mut request = serde_json::json!({ + "model": model, + "prompt": prompt_tokens, + "max_tokens": 16384, + "temperature": sampling.temperature, + "top_p": sampling.top_p, + "top_k": sampling.top_k, + "stream": true, + "return_token_ids": true, + "skip_special_tokens": false, + "stop_token_ids": [super::tokenizer::IM_END], + }); + if let Some(p) = priority { + request["priority"] = serde_json::json!(p); + } + + let url = format!("{}/completions", base_url); + let debug_label = format!("{} prompt tokens, model={}", prompt_tokens.len(), model); + + let mut response = send_and_check( + client, &url, &request, + ("Authorization", &format!("Bearer {}", api_key)), + &[], &debug_label, None, + ).await?; + + let mut reader = SseReader::new(); + let mut usage = None; + + while let Some(event) = reader.next_event(&mut response).await? 
{ + if let Some(err_msg) = event["error"]["message"].as_str() { + anyhow::bail!("API error in stream: {}", err_msg); + } + + if let Some(u) = event["usage"].as_object() { + if let Ok(u) = serde_json::from_value::<Usage>(serde_json::Value::Object(u.clone())) { + usage = Some(u); + } + } + + let choices = match event["choices"].as_array() { + Some(c) => c, + None => continue, + }; + + for choice in choices { + let text = choice["text"].as_str().unwrap_or(""); + let token_ids = choice["token_ids"].as_array(); + + if let Some(ids) = token_ids { + for (i, id_val) in ids.iter().enumerate() { + if let Some(id) = id_val.as_u64() { + let _ = tx.send(StreamToken::Token { + text: if i == 0 { text.to_string() } else { String::new() }, + id: id as u32, + }); + } + } + } else if !text.is_empty() { + let _ = tx.send(StreamToken::Token { text: text.to_string(), id: 0 }); + } + } + } + + let _ = tx.send(StreamToken::Done { usage }); + Ok(()) +} + +/// Send an HTTP request and check for errors. pub(crate) async fn send_and_check( client: &HttpClient, url: &str, diff --git a/src/agent/api/openai.rs b/src/agent/api/openai.rs deleted file mode 100644 index 6577037..0000000 --- a/src/agent/api/openai.rs +++ /dev/null @@ -1,97 +0,0 @@ -// api/openai.rs — OpenAI-compatible backend -// -// Works with any provider that implements the OpenAI chat completions -// API: OpenRouter, vLLM, llama.cpp, Fireworks, Together, etc. -// Also used for local models (Qwen, llama) via compatible servers. - -use anyhow::Result; -use tokio::sync::mpsc; - -use super::http::HttpClient; -use super::types::*; -use super::StreamToken; - -/// Stream from /v1/completions with raw token IDs in and out. -/// Each SSE chunk yields one token (text + id). All parsing (think tags, -/// tool calls) is handled by the ResponseParser, not here. 
-pub(super) async fn stream_completions( - client: &HttpClient, - base_url: &str, - api_key: &str, - model: &str, - prompt_tokens: &[u32], - tx: &mpsc::UnboundedSender<StreamToken>, - sampling: super::SamplingParams, - priority: Option<i32>, -) -> Result<()> { - let mut request = serde_json::json!({ - "model": model, - "prompt": prompt_tokens, - "max_tokens": 16384, - "temperature": sampling.temperature, - "top_p": sampling.top_p, - "top_k": sampling.top_k, - "stream": true, - "return_token_ids": true, - "skip_special_tokens": false, - "stop_token_ids": [super::super::tokenizer::IM_END], - }); - if let Some(p) = priority { - request["priority"] = serde_json::json!(p); - } - - let url = format!("{}/completions", base_url); - let debug_label = format!("{} prompt tokens, model={}", prompt_tokens.len(), model); - - let mut response = super::send_and_check( - client, - &url, - &request, - ("Authorization", &format!("Bearer {}", api_key)), - &[], - &debug_label, - None, - ) - .await?; - - let mut reader = super::SseReader::new(); - let mut usage = None; - - while let Some(event) = reader.next_event(&mut response).await? 
{ - if let Some(err_msg) = event["error"]["message"].as_str() { - anyhow::bail!("API error in stream: {}", err_msg); - } - - if let Some(u) = event["usage"].as_object() { - if let Ok(u) = serde_json::from_value::<Usage>(serde_json::Value::Object(u.clone())) { - usage = Some(u); - } - } - - let choices = match event["choices"].as_array() { - Some(c) => c, - None => continue, - }; - - for choice in choices { - let text = choice["text"].as_str().unwrap_or(""); - let token_ids = choice["token_ids"].as_array(); - - if let Some(ids) = token_ids { - for (i, id_val) in ids.iter().enumerate() { - if let Some(id) = id_val.as_u64() { - let _ = tx.send(StreamToken::Token { - text: if i == 0 { text.to_string() } else { String::new() }, - id: id as u32, - }); - } - } - } else if !text.is_empty() { - let _ = tx.send(StreamToken::Token { text: text.to_string(), id: 0 }); - } - } - } - - let _ = tx.send(StreamToken::Done { usage }); - Ok(()) -} diff --git a/src/agent/api/types.rs b/src/agent/api/types.rs deleted file mode 100644 index 8b000af..0000000 --- a/src/agent/api/types.rs +++ /dev/null @@ -1,8 +0,0 @@ -use serde::Deserialize; - -#[derive(Debug, Clone, Deserialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, -}