call_api_with_tools_sync() -> src/agent/oneshot.rs

This commit is contained in:
Kent Overstreet 2026-04-07 00:57:35 -04:00
parent da24e02159
commit cbf7653cdf
9 changed files with 638 additions and 234 deletions

View file

@ -13,6 +13,27 @@ use crate::subconscious::{defs, prompts};
use std::fs;
use std::path::PathBuf;
use std::sync::OnceLock;
use super::api::ApiClient;
use super::api::types::*;
use super::tools::{self as agent_tools};
// ---------------------------------------------------------------------------
// API client — shared across oneshot agent runs
// ---------------------------------------------------------------------------
static API_CLIENT: OnceLock<ApiClient> = OnceLock::new();
fn get_client() -> Result<&'static ApiClient, String> {
Ok(API_CLIENT.get_or_init(|| {
let config = crate::config::get();
let base_url = config.api_base_url.as_deref().unwrap_or("");
let api_key = config.api_key.as_deref().unwrap_or("");
let model = config.api_model.as_deref().unwrap_or("qwen-2.5-27b");
ApiClient::new(base_url, api_key, model)
}))
}
// ---------------------------------------------------------------------------
// Agent execution
@ -143,7 +164,7 @@ pub fn run_one_agent(
Ok(())
};
let output = crate::subconscious::api::call_api_with_tools_sync(
let output = call_api_with_tools_sync(
agent_name, &prompts, &step_phases, def.temperature, def.priority,
&effective_tools, Some(&bail_fn), log)?;
@ -154,6 +175,201 @@ pub fn run_one_agent(
})
}
// ---------------------------------------------------------------------------
// Multi-step API turn loop
// ---------------------------------------------------------------------------
/// Run agent prompts through the API with tool support.
/// For multi-step agents, each prompt is injected as a new user message
/// after the previous step's tool loop completes. The conversation
/// context carries forward naturally between steps.
/// Returns the final text response after all steps complete.
pub async fn call_api_with_tools(
agent: &str,
prompts: &[String],
phases: &[String],
temperature: Option<f32>,
priority: i32,
tools: &[agent_tools::Tool],
bail_fn: Option<&(dyn Fn(usize) -> Result<(), String> + Sync)>,
log: &dyn Fn(&str),
) -> Result<String, String> {
let client = get_client()?;
let first_phase = phases.first().map(|s| s.as_str()).unwrap_or("");
let _provenance = std::cell::RefCell::new(
if first_phase.is_empty() { format!("agent:{}", agent) }
else { format!("agent:{}:{}", agent, first_phase) }
);
let mut messages = vec![Message::user(&prompts[0])];
let mut next_prompt_idx = 1;
let reasoning = crate::config::get().api_reasoning.clone();
let max_turns = 50 * prompts.len();
for turn in 0..max_turns {
log(&format!("\n=== TURN {} ({} messages) ===\n", turn, messages.len()));
let mut last_err = None;
let mut msg_opt = None;
let mut usage_opt = None;
for attempt in 0..5 {
let sampling = super::api::SamplingParams {
temperature: temperature.unwrap_or(0.6),
top_p: 0.95,
top_k: 20,
};
match client.chat_completion_stream_temp(
&messages,
tools,
&reasoning,
sampling,
Some(priority),
).await {
Ok((msg, usage)) => {
msg_opt = Some(msg);
usage_opt = usage;
break;
}
Err(e) => {
let err_str = e.to_string();
let is_transient = err_str.contains("IncompleteMessage")
|| err_str.contains("connection closed")
|| err_str.contains("connection reset")
|| err_str.contains("timed out")
|| err_str.contains("Connection refused");
if is_transient && attempt < 4 {
log(&format!("transient error (attempt {}): {}, retrying...",
attempt + 1, err_str));
tokio::time::sleep(std::time::Duration::from_secs(2 << attempt)).await;
last_err = Some(e);
continue;
}
let msg_bytes: usize = messages.iter()
.map(|m| m.content_text().len())
.sum();
return Err(format!(
"API error on turn {} (~{}KB payload, {} messages, {} attempts): {}",
turn, msg_bytes / 1024, messages.len(), attempt + 1, e));
}
}
}
let msg = msg_opt.unwrap();
if let Some(ref e) = last_err {
log(&format!("succeeded after retry (previous error: {})", e));
}
if let Some(u) = &usage_opt {
log(&format!("tokens: {} prompt + {} completion",
u.prompt_tokens, u.completion_tokens));
}
let has_content = msg.content.is_some();
let has_tools = msg.tool_calls.as_ref().is_some_and(|tc| !tc.is_empty());
if has_tools {
let mut sanitized = msg.clone();
if let Some(ref mut calls) = sanitized.tool_calls {
for call in calls {
if serde_json::from_str::<serde_json::Value>(&call.function.arguments).is_err() {
log(&format!("sanitizing malformed args for {}: {}",
call.function.name, &call.function.arguments));
call.function.arguments = "{}".to_string();
}
}
}
messages.push(sanitized);
for call in msg.tool_calls.as_ref().unwrap() {
log(&format!("\nTOOL CALL: {}({})",
call.function.name,
&call.function.arguments));
let args: serde_json::Value = match serde_json::from_str(&call.function.arguments) {
Ok(v) => v,
Err(_) => {
log(&format!("malformed tool call args: {}", &call.function.arguments));
messages.push(Message::tool_result(
&call.id,
"Error: your tool call had malformed JSON arguments. Please retry with valid JSON.",
));
continue;
}
};
let output = agent_tools::dispatch(&call.function.name, &args).await;
if std::env::var("POC_AGENT_VERBOSE").is_ok() {
log(&format!("TOOL RESULT ({} chars):\n{}", output.len(), output));
} else {
let preview: String = output.lines().next().unwrap_or("").chars().take(100).collect();
log(&format!("Result: {}", preview));
}
messages.push(Message::tool_result(&call.id, &output));
}
continue;
}
// Text-only response — step complete
let text = msg.content_text().to_string();
if text.is_empty() && !has_content {
log("empty response, retrying");
messages.push(Message::user(
"[system] Your previous response was empty. Please respond with text or use a tool."
));
continue;
}
log(&format!("\n=== RESPONSE ===\n\n{}", text));
// If there are more prompts, check bail condition and inject the next one
if next_prompt_idx < prompts.len() {
if let Some(ref check) = bail_fn {
check(next_prompt_idx)?;
}
if let Some(phase) = phases.get(next_prompt_idx) {
*_provenance.borrow_mut() = format!("agent:{}:{}", agent, phase);
}
messages.push(Message::assistant(&text));
let next = &prompts[next_prompt_idx];
next_prompt_idx += 1;
log(&format!("\n=== STEP {}/{} ===\n", next_prompt_idx, prompts.len()));
messages.push(Message::user(next));
continue;
}
return Ok(text);
}
Err(format!("agent exceeded {} tool turns", max_turns))
}
/// Synchronous wrapper — runs the async function on a dedicated thread
/// with its own tokio runtime. Safe to call from any context.
pub fn call_api_with_tools_sync(
agent: &str,
prompts: &[String],
phases: &[String],
temperature: Option<f32>,
priority: i32,
tools: &[agent_tools::Tool],
bail_fn: Option<&(dyn Fn(usize) -> Result<(), String> + Sync)>,
log: &(dyn Fn(&str) + Sync),
) -> Result<String, String> {
std::thread::scope(|s| {
s.spawn(|| {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| format!("tokio runtime: {}", e))?;
rt.block_on(
call_api_with_tools(agent, prompts, phases, temperature, priority, tools, bail_fn, log)
)
}).join().unwrap()
})
}
// ---------------------------------------------------------------------------
// Process management — PID tracking and subprocess spawning
// ---------------------------------------------------------------------------