diff --git a/src/agent/api/mod.rs b/src/agent/api/mod.rs
index 6822dfd..695994c 100644
--- a/src/agent/api/mod.rs
+++ b/src/agent/api/mod.rs
@@ -133,6 +133,34 @@ impl ApiClient {
         (rx, AbortOnDrop(handle))
     }
 
+    /// Start a streaming completion with raw token IDs.
+    /// No message formatting — the caller provides the complete prompt as tokens.
+    pub(crate) fn start_stream_completions(
+        &self,
+        prompt_tokens: &[u32],
+        sampling: SamplingParams,
+        priority: Option<i32>,
+    ) -> (mpsc::UnboundedReceiver<StreamEvent>, AbortOnDrop) {
+        let (tx, rx) = mpsc::unbounded_channel();
+        let client = self.client.clone();
+        let api_key = self.api_key.clone();
+        let model = self.model.clone();
+        let prompt_tokens = prompt_tokens.to_vec();
+        let base_url = self.base_url.clone();
+
+        let handle = tokio::spawn(async move {
+            let result = openai::stream_completions(
+                &client, &base_url, &api_key, &model,
+                &prompt_tokens, &tx, sampling, priority,
+            ).await;
+            if let Err(e) = result {
+                let _ = tx.send(StreamEvent::Error(e.to_string()));
+            }
+        });
+
+        (rx, AbortOnDrop(handle))
+    }
+
     pub(crate) async fn chat_completion_stream_temp(
         &self,
         messages: &[Message],
diff --git a/src/agent/api/openai.rs b/src/agent/api/openai.rs
index abf992f..f4da2a6 100644
--- a/src/agent/api/openai.rs
+++ b/src/agent/api/openai.rs
@@ -185,3 +185,146 @@ pub(super) async fn stream_events(
 
     Ok(())
 }
+
+/// Stream from the /v1/completions endpoint using raw token IDs.
+/// Tool calls come as text (<tool_call> tags) and are parsed by the caller.
+/// Thinking content comes as <think> tags and is split into Reasoning events.
+pub(super) async fn stream_completions(
+    client: &HttpClient,
+    base_url: &str,
+    api_key: &str,
+    model: &str,
+    prompt_tokens: &[u32],
+    tx: &mpsc::UnboundedSender<StreamEvent>,
+    sampling: super::SamplingParams,
+    priority: Option<i32>,
+) -> Result<()> {
+    let mut request = serde_json::json!({
+        "model": model,
+        "prompt": prompt_tokens,
+        "max_tokens": 16384,
+        "temperature": sampling.temperature,
+        "top_p": sampling.top_p,
+        "top_k": sampling.top_k,
+        "stream": true,
+        "stop_token_ids": [super::super::tokenizer::IM_END],
+    });
+    if let Some(p) = priority {
+        request["priority"] = serde_json::json!(p);
+    }
+
+    let url = format!("{}/completions", base_url);
+    let debug_label = format!("{} prompt tokens, model={}", prompt_tokens.len(), model);
+
+    let mut response = super::send_and_check(
+        client,
+        &url,
+        &request,
+        ("Authorization", &format!("Bearer {}", api_key)),
+        &[],
+        &debug_label,
+        None,
+    )
+    .await?;
+
+    let mut reader = super::SseReader::new();
+    let mut content_len: usize = 0;
+    let mut first_content_at = None;
+    let mut finish_reason = None;
+    let mut usage = None;
+    let mut in_think = false;
+
+    while let Some(event) = reader.next_event(&mut response).await? {
+        if let Some(err_msg) = event["error"]["message"].as_str() {
+            anyhow::bail!("API error in stream: {}", err_msg);
+        }
+
+        // Completions chunks have a simpler structure
+        if let Some(u) = event["usage"].as_object() {
+            if let Ok(u) = serde_json::from_value::<Usage>(serde_json::Value::Object(u.clone())) {
+                let _ = tx.send(StreamEvent::Usage(u.clone()));
+                usage = Some(u);
+            }
+        }
+
+        let choices = match event["choices"].as_array() {
+            Some(c) => c,
+            None => continue,
+        };
+
+        for choice in choices {
+            if let Some(reason) = choice["finish_reason"].as_str() {
+                if reason != "null" {
+                    finish_reason = Some(reason.to_string());
+                }
+            }
+
+            if let Some(text) = choice["text"].as_str() {
+                if text.is_empty() { continue; }
+
+                // Handle <think> tags — split into Reasoning vs Content
+                if text.contains("<think>") || in_think {
+                    // Simple state machine for think tags
+                    let mut remaining = text;
+                    while !remaining.is_empty() {
+                        if in_think {
+                            if let Some(end) = remaining.find("</think>") {
+                                let thinking = &remaining[..end];
+                                if !thinking.is_empty() {
+                                    let _ = tx.send(StreamEvent::Reasoning(thinking.to_string()));
+                                }
+                                remaining = &remaining[end + 8..];
+                                in_think = false;
+                            } else {
+                                let _ = tx.send(StreamEvent::Reasoning(remaining.to_string()));
+                                break;
+                            }
+                        } else {
+                            if let Some(start) = remaining.find("<think>") {
+                                let content = &remaining[..start];
+                                if !content.is_empty() {
+                                    content_len += content.len();
+                                    if first_content_at.is_none() {
+                                        first_content_at = Some(reader.stream_start.elapsed());
+                                    }
+                                    let _ = tx.send(StreamEvent::Content(content.to_string()));
+                                }
+                                remaining = &remaining[start + 7..];
+                                in_think = true;
+                            } else {
+                                content_len += remaining.len();
+                                if first_content_at.is_none() {
+                                    first_content_at = Some(reader.stream_start.elapsed());
+                                }
+                                let _ = tx.send(StreamEvent::Content(remaining.to_string()));
+                                break;
+                            }
+                        }
+                    }
+                } else {
+                    content_len += text.len();
+                    if first_content_at.is_none() {
+                        first_content_at = Some(reader.stream_start.elapsed());
+                    }
+                    let _ = tx.send(StreamEvent::Content(text.to_string()));
+                }
+            }
+        }
+    }
+
+    let total_elapsed = reader.stream_start.elapsed();
+    super::log_diagnostics(
+        content_len, 0, 0, "none",
+        &finish_reason,
+        reader.chunks_received,
+        reader.sse_lines_parsed,
+        reader.sse_parse_errors,
+        0, total_elapsed, first_content_at,
+        &usage, &[],
+    );
+
+    let reason = finish_reason.unwrap_or_default();
+    let _ = tx.send(StreamEvent::Finished { reason });
+
+    Ok(())
+}
diff --git a/src/agent/mod.rs b/src/agent/mod.rs
index 8f021e7..6fcd403 100644
--- a/src/agent/mod.rs
+++ b/src/agent/mod.rs
@@ -483,19 +483,28 @@ impl Agent {
         let _thinking = start_activity(&agent, "thinking...").await;
         let (mut rx, _stream_guard) = {
             let me = agent.lock().await;
-            let api_messages = me.assemble_api_messages();
             let sampling = api::SamplingParams {
                 temperature: me.temperature,
                 top_p: me.top_p,
                 top_k: me.top_k,
             };
-            me.client.start_stream(
-                &api_messages,
-                &me.tools,
-                &me.reasoning_effort,
-                sampling,
-                None,
-            )
+            if tokenizer::is_initialized() {
+                let prompt_tokens = me.assemble_prompt_tokens();
+                me.client.start_stream_completions(
+                    &prompt_tokens,
+                    sampling,
+                    None,
+                )
+            } else {
+                let api_messages = me.assemble_api_messages();
+                me.client.start_stream(
+                    &api_messages,
+                    &me.tools,
+                    &me.reasoning_effort,
+                    sampling,
+                    None,
+                )
+            }
         };
 
         // --- Lock released ---