Collapse API layer: inline openai.rs, delete types.rs and parsing.rs
API is now two files: mod.rs (430 lines) and http.rs. Contains: Usage, StreamToken, SamplingParams, ApiClient, stream_completions, SseReader, send_and_check. Everything else is dead and gone. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
9bb626f18c
commit
22146156d4
3 changed files with 87 additions and 113 deletions
|
|
@ -7,10 +7,6 @@
|
||||||
// Set POC_DEBUG=1 for verbose per-turn logging.
|
// Set POC_DEBUG=1 for verbose per-turn logging.
|
||||||
|
|
||||||
pub mod http;
|
pub mod http;
|
||||||
mod types;
|
|
||||||
mod openai;
|
|
||||||
|
|
||||||
pub use types::Usage;
|
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
@ -18,6 +14,14 @@ use std::time::{Duration, Instant};
|
||||||
use self::http::{HttpClient, HttpResponse};
|
use self::http::{HttpClient, HttpResponse};
|
||||||
|
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub completion_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
/// A JoinHandle that aborts its task when dropped.
|
/// A JoinHandle that aborts its task when dropped.
|
||||||
pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>);
|
pub(crate) struct AbortOnDrop(tokio::task::JoinHandle<()>);
|
||||||
|
|
@ -70,8 +74,6 @@ impl ApiClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stream a completion with raw token IDs.
|
|
||||||
/// Returns (text, token_id) per token via channel.
|
|
||||||
pub(crate) fn stream_completion(
|
pub(crate) fn stream_completion(
|
||||||
&self,
|
&self,
|
||||||
prompt_tokens: &[u32],
|
prompt_tokens: &[u32],
|
||||||
|
|
@ -86,7 +88,7 @@ impl ApiClient {
|
||||||
let base_url = self.base_url.clone();
|
let base_url = self.base_url.clone();
|
||||||
|
|
||||||
let handle = tokio::spawn(async move {
|
let handle = tokio::spawn(async move {
|
||||||
let result = openai::stream_completions(
|
let result = stream_completions(
|
||||||
&client, &base_url, &api_key, &model,
|
&client, &base_url, &api_key, &model,
|
||||||
&prompt_tokens, &tx, sampling, priority,
|
&prompt_tokens, &tx, sampling, priority,
|
||||||
).await;
|
).await;
|
||||||
|
|
@ -103,7 +105,84 @@ impl ApiClient {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send an HTTP request and check for errors. Shared by both backends.
|
async fn stream_completions(
|
||||||
|
client: &HttpClient,
|
||||||
|
base_url: &str,
|
||||||
|
api_key: &str,
|
||||||
|
model: &str,
|
||||||
|
prompt_tokens: &[u32],
|
||||||
|
tx: &mpsc::UnboundedSender<StreamToken>,
|
||||||
|
sampling: SamplingParams,
|
||||||
|
priority: Option<i32>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let mut request = serde_json::json!({
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt_tokens,
|
||||||
|
"max_tokens": 16384,
|
||||||
|
"temperature": sampling.temperature,
|
||||||
|
"top_p": sampling.top_p,
|
||||||
|
"top_k": sampling.top_k,
|
||||||
|
"stream": true,
|
||||||
|
"return_token_ids": true,
|
||||||
|
"skip_special_tokens": false,
|
||||||
|
"stop_token_ids": [super::tokenizer::IM_END],
|
||||||
|
});
|
||||||
|
if let Some(p) = priority {
|
||||||
|
request["priority"] = serde_json::json!(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
let url = format!("{}/completions", base_url);
|
||||||
|
let debug_label = format!("{} prompt tokens, model={}", prompt_tokens.len(), model);
|
||||||
|
|
||||||
|
let mut response = send_and_check(
|
||||||
|
client, &url, &request,
|
||||||
|
("Authorization", &format!("Bearer {}", api_key)),
|
||||||
|
&[], &debug_label, None,
|
||||||
|
).await?;
|
||||||
|
|
||||||
|
let mut reader = SseReader::new();
|
||||||
|
let mut usage = None;
|
||||||
|
|
||||||
|
while let Some(event) = reader.next_event(&mut response).await? {
|
||||||
|
if let Some(err_msg) = event["error"]["message"].as_str() {
|
||||||
|
anyhow::bail!("API error in stream: {}", err_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(u) = event["usage"].as_object() {
|
||||||
|
if let Ok(u) = serde_json::from_value::<Usage>(serde_json::Value::Object(u.clone())) {
|
||||||
|
usage = Some(u);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let choices = match event["choices"].as_array() {
|
||||||
|
Some(c) => c,
|
||||||
|
None => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
for choice in choices {
|
||||||
|
let text = choice["text"].as_str().unwrap_or("");
|
||||||
|
let token_ids = choice["token_ids"].as_array();
|
||||||
|
|
||||||
|
if let Some(ids) = token_ids {
|
||||||
|
for (i, id_val) in ids.iter().enumerate() {
|
||||||
|
if let Some(id) = id_val.as_u64() {
|
||||||
|
let _ = tx.send(StreamToken::Token {
|
||||||
|
text: if i == 0 { text.to_string() } else { String::new() },
|
||||||
|
id: id as u32,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if !text.is_empty() {
|
||||||
|
let _ = tx.send(StreamToken::Token { text: text.to_string(), id: 0 });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = tx.send(StreamToken::Done { usage });
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send an HTTP request and check for errors.
|
||||||
pub(crate) async fn send_and_check(
|
pub(crate) async fn send_and_check(
|
||||||
client: &HttpClient,
|
client: &HttpClient,
|
||||||
url: &str,
|
url: &str,
|
||||||
|
|
|
||||||
|
|
@ -1,97 +0,0 @@
|
||||||
// api/openai.rs — OpenAI-compatible backend
|
|
||||||
//
|
|
||||||
// Works with any provider that implements the OpenAI chat completions
|
|
||||||
// API: OpenRouter, vLLM, llama.cpp, Fireworks, Together, etc.
|
|
||||||
// Also used for local models (Qwen, llama) via compatible servers.
|
|
||||||
|
|
||||||
use anyhow::Result;
|
|
||||||
use tokio::sync::mpsc;
|
|
||||||
|
|
||||||
use super::http::HttpClient;
|
|
||||||
use super::types::*;
|
|
||||||
use super::StreamToken;
|
|
||||||
|
|
||||||
/// Stream from /v1/completions with raw token IDs in and out.
|
|
||||||
/// Each SSE chunk yields one token (text + id). All parsing (think tags,
|
|
||||||
/// tool calls) is handled by the ResponseParser, not here.
|
|
||||||
pub(super) async fn stream_completions(
|
|
||||||
client: &HttpClient,
|
|
||||||
base_url: &str,
|
|
||||||
api_key: &str,
|
|
||||||
model: &str,
|
|
||||||
prompt_tokens: &[u32],
|
|
||||||
tx: &mpsc::UnboundedSender<StreamToken>,
|
|
||||||
sampling: super::SamplingParams,
|
|
||||||
priority: Option<i32>,
|
|
||||||
) -> Result<()> {
|
|
||||||
let mut request = serde_json::json!({
|
|
||||||
"model": model,
|
|
||||||
"prompt": prompt_tokens,
|
|
||||||
"max_tokens": 16384,
|
|
||||||
"temperature": sampling.temperature,
|
|
||||||
"top_p": sampling.top_p,
|
|
||||||
"top_k": sampling.top_k,
|
|
||||||
"stream": true,
|
|
||||||
"return_token_ids": true,
|
|
||||||
"skip_special_tokens": false,
|
|
||||||
"stop_token_ids": [super::super::tokenizer::IM_END],
|
|
||||||
});
|
|
||||||
if let Some(p) = priority {
|
|
||||||
request["priority"] = serde_json::json!(p);
|
|
||||||
}
|
|
||||||
|
|
||||||
let url = format!("{}/completions", base_url);
|
|
||||||
let debug_label = format!("{} prompt tokens, model={}", prompt_tokens.len(), model);
|
|
||||||
|
|
||||||
let mut response = super::send_and_check(
|
|
||||||
client,
|
|
||||||
&url,
|
|
||||||
&request,
|
|
||||||
("Authorization", &format!("Bearer {}", api_key)),
|
|
||||||
&[],
|
|
||||||
&debug_label,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let mut reader = super::SseReader::new();
|
|
||||||
let mut usage = None;
|
|
||||||
|
|
||||||
while let Some(event) = reader.next_event(&mut response).await? {
|
|
||||||
if let Some(err_msg) = event["error"]["message"].as_str() {
|
|
||||||
anyhow::bail!("API error in stream: {}", err_msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(u) = event["usage"].as_object() {
|
|
||||||
if let Ok(u) = serde_json::from_value::<Usage>(serde_json::Value::Object(u.clone())) {
|
|
||||||
usage = Some(u);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let choices = match event["choices"].as_array() {
|
|
||||||
Some(c) => c,
|
|
||||||
None => continue,
|
|
||||||
};
|
|
||||||
|
|
||||||
for choice in choices {
|
|
||||||
let text = choice["text"].as_str().unwrap_or("");
|
|
||||||
let token_ids = choice["token_ids"].as_array();
|
|
||||||
|
|
||||||
if let Some(ids) = token_ids {
|
|
||||||
for (i, id_val) in ids.iter().enumerate() {
|
|
||||||
if let Some(id) = id_val.as_u64() {
|
|
||||||
let _ = tx.send(StreamToken::Token {
|
|
||||||
text: if i == 0 { text.to_string() } else { String::new() },
|
|
||||||
id: id as u32,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if !text.is_empty() {
|
|
||||||
let _ = tx.send(StreamToken::Token { text: text.to_string(), id: 0 });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let _ = tx.send(StreamToken::Done { usage });
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
use serde::Deserialize;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
|
||||||
pub struct Usage {
|
|
||||||
pub prompt_tokens: u32,
|
|
||||||
pub completion_tokens: u32,
|
|
||||||
pub total_tokens: u32,
|
|
||||||
}
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue