consciousness/src/agent/oneshot.rs

449 lines
17 KiB
Rust
Raw Normal View History

// oneshot.rs — One-shot agent execution
//
// Runs an agent definition (from agents/*.agent files) through the API:
// build prompt → call LLM with tools → return result. Agents apply
// changes via tool calls during the LLM call — no action parsing needed.
//
// This is distinct from the interactive agent loop in agent/mod.rs:
// oneshot agents run a fixed prompt sequence and exit, while the
// interactive agent has a turn loop with streaming and TUI.
use crate::store::{self, Store};
use crate::subconscious::{defs, prompts};
use std::fs;
use std::path::PathBuf;
use std::sync::OnceLock;
use super::api::ApiClient;
use super::api::types::*;
use super::tools::{self as agent_tools};
// ---------------------------------------------------------------------------
// API client — shared across oneshot agent runs
// ---------------------------------------------------------------------------
/// Process-wide API client, built lazily on first use.
static API_CLIENT: OnceLock<ApiClient> = OnceLock::new();

/// Return the shared [`ApiClient`], constructing it from the global config
/// the first time it is requested.
///
/// Missing `api_base_url` / `api_key` settings fall back to the empty string
/// and a missing `api_model` falls back to `"qwen-2.5-27b"`. The `Result`
/// wrapper is always `Ok` here; it lets callers use `?`.
fn get_client() -> Result<&'static ApiClient, String> {
    let client = API_CLIENT.get_or_init(|| {
        let cfg = crate::config::get();
        ApiClient::new(
            cfg.api_base_url.as_deref().unwrap_or(""),
            cfg.api_key.as_deref().unwrap_or(""),
            cfg.api_model.as_deref().unwrap_or("qwen-2.5-27b"),
        )
    });
    Ok(client)
}
// ---------------------------------------------------------------------------
// Agent execution
// ---------------------------------------------------------------------------
/// Result of running a single agent.
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// Final text response returned by the LLM after all steps completed.
    pub output: String,
    /// Keys of the nodes the prompt batch was built from.
    pub node_keys: Vec<String>,
    /// Directory containing output() files from the agent run.
    pub state_dir: PathBuf,
}
/// Run an agent. If keys are provided, use them directly (bypassing the
/// agent's query). Otherwise, run the query to select target nodes.
///
/// * `store` — memory store used to build the graph, resolve placeholders
///   and record agent visits.
/// * `agent_name` — name looked up via `defs::get_def`.
/// * `count` — passed to placeholder resolution when `keys` is given;
///   otherwise the definition's own `count` (if set) takes precedence.
/// * `keys` — explicit target node keys, or `None` to run the agent's query.
/// * `log` — sink for progress messages.
///
/// # Errors
///
/// Returns `Err` when the agent definition is missing, the state directory
/// cannot be created, the definition produces no steps, the first prompt
/// exceeds 800 000 bytes, a bail script fails, or the API call fails.
pub fn run_one_agent(
    store: &mut Store,
    agent_name: &str,
    count: usize,
    keys: Option<&[String]>,
    log: &(dyn Fn(&str) + Sync),
) -> Result<AgentResult, String> {
    let def = defs::get_def(agent_name)
        .ok_or_else(|| format!("no .agent file for {}", agent_name))?;

    // State dir for agent output files
    let state_dir = std::env::var("POC_AGENT_OUTPUT_DIR")
        .map(PathBuf::from)
        .unwrap_or_else(|_| store::memory_dir().join("agent-output").join(agent_name));
    fs::create_dir_all(&state_dir)
        .map_err(|e| format!("create state dir: {}", e))?;
    // SAFETY: `set_var` is only sound while no other thread is concurrently
    // reading or writing the process environment.
    // NOTE(review): this function does not enforce that itself — confirm
    // callers do not run agents concurrently with other env access.
    unsafe { std::env::set_var("POC_AGENT_OUTPUT_DIR", &state_dir); }

    // Build prompt batch — either from explicit keys or the agent's query
    let agent_batch = if let Some(keys) = keys {
        log(&format!("targeting: {}", keys.join(", ")));
        let graph = store.build_graph();
        let mut resolved_steps = Vec::new();
        let mut all_keys: Vec<String> = keys.to_vec();
        for step in &def.steps {
            let (prompt, extra_keys) = defs::resolve_placeholders(
                &step.prompt, store, &graph, keys, count,
            );
            all_keys.extend(extra_keys);
            resolved_steps.push(prompts::ResolvedStep {
                prompt,
                phase: step.phase.clone(),
            });
        }
        let batch = prompts::AgentBatch { steps: resolved_steps, node_keys: all_keys };
        if !batch.node_keys.is_empty() {
            // Best effort: failing to record visits must not abort the run.
            store.record_agent_visits(&batch.node_keys, agent_name).ok();
        }
        batch
    } else {
        log("building prompt");
        let effective_count = def.count.unwrap_or(count);
        defs::run_agent(store, &def, effective_count, &Default::default())?
    };

    // Filter tools based on agent def: an empty list means "all tools".
    let all_tools = super::tools::memory_and_journal_tools();
    let effective_tools: Vec<super::tools::Tool> = if def.tools.is_empty() {
        all_tools.to_vec()
    } else {
        all_tools.into_iter()
            .filter(|t| def.tools.iter().any(|w| w == &t.name))
            .collect()
    };
    let tools_desc = effective_tools.iter().map(|t| t.name).collect::<Vec<_>>().join(", ");
    let n_steps = agent_batch.steps.len();

    for key in &agent_batch.node_keys {
        log(&format!("  node: {}", key));
    }

    // Guard: a definition without steps has nothing to send. Indexing
    // `steps[0]` here would panic instead of reporting the problem.
    let Some(first_step) = agent_batch.steps.first() else {
        return Err(format!("agent {} has no steps", agent_name));
    };

    // Guard: reject oversized first prompt
    let max_prompt_bytes = 800_000;
    let first_len = first_step.prompt.len();
    if first_len > max_prompt_bytes {
        let prompt_kb = first_len / 1024;
        // Keep a copy of the offending prompt for later inspection.
        let oversize_dir = store::memory_dir().join("llm-logs").join("oversized");
        fs::create_dir_all(&oversize_dir).ok();
        let oversize_path = oversize_dir.join(format!("{}-{}.txt",
            agent_name, store::compact_timestamp()));
        let header = format!("=== OVERSIZED PROMPT ===\nagent: {}\nsize: {}KB (max {}KB)\nnodes: {:?}\n\n",
            agent_name, prompt_kb, max_prompt_bytes / 1024, agent_batch.node_keys);
        fs::write(&oversize_path, format!("{}{}", header, &first_step.prompt)).ok();
        log(&format!("oversized prompt logged to {}", oversize_path.display()));
        return Err(format!(
            "prompt too large: {}KB (max {}KB) — seed nodes may be oversized",
            prompt_kb, max_prompt_bytes / 1024,
        ));
    }

    let phases: Vec<&str> = agent_batch.steps.iter().map(|s| s.phase.as_str()).collect();
    log(&format!("{} step(s) {:?}, {}KB initial, {}, {} nodes, output={}",
        n_steps, phases, first_len / 1024, tools_desc,
        agent_batch.node_keys.len(), state_dir.display()));

    let prompts: Vec<String> = agent_batch.steps.iter()
        .map(|s| s.prompt.clone()).collect();
    let step_phases: Vec<String> = agent_batch.steps.iter()
        .map(|s| s.phase.clone()).collect();

    if std::env::var("POC_AGENT_VERBOSE").is_ok() {
        for (i, s) in agent_batch.steps.iter().enumerate() {
            log(&format!("=== PROMPT {}/{} ({}) ===\n\n{}", i + 1, n_steps, s.phase, s.prompt));
        }
    }
    log("\n=== CALLING LLM ===");

    // Bail check: if the agent defines a bail script, run it between steps.
    // The script runs with the state dir as its working directory; a
    // non-zero exit status aborts the remaining steps.
    let bail_script = def.bail.as_ref().map(|name| defs::agents_dir().join(name));
    let state_dir_for_bail = state_dir.clone();
    let bail_fn = move |step_idx: usize| -> Result<(), String> {
        if let Some(ref script) = bail_script {
            let status = std::process::Command::new(script)
                .current_dir(&state_dir_for_bail)
                .status()
                .map_err(|e| format!("bail script {:?} failed: {}", script, e))?;
            if !status.success() {
                return Err(format!("bailed at step {}: {:?} exited {}",
                    step_idx + 1, script.file_name().unwrap_or_default(),
                    status.code().unwrap_or(-1)));
            }
        }
        Ok(())
    };

    let output = call_api_with_tools_sync(
        agent_name, &prompts, &step_phases, def.temperature, def.priority,
        &effective_tools, Some(&bail_fn), log)?;

    Ok(AgentResult {
        output,
        node_keys: agent_batch.node_keys,
        state_dir,
    })
}
// ---------------------------------------------------------------------------
// Multi-step API turn loop
// ---------------------------------------------------------------------------
/// Run agent prompts through the API with tool support.
/// For multi-step agents, each prompt is injected as a new user message
/// after the previous step's tool loop completes. The conversation
/// context carries forward naturally between steps.
/// Returns the final text response after all steps complete.
///
/// * `agent` — agent name, used for labelling.
/// * `prompts` — one prompt per step; must not be empty.
/// * `phases` — phase name per step (looked up by step index).
/// * `temperature` — sampling temperature; defaults to `0.6` when `None`.
/// * `priority` — forwarded to the API client.
/// * `tools` — tools offered to the model.
/// * `bail_fn` — optional check run before each follow-up step; an `Err`
///   from it aborts the run.
/// * `log` — sink for progress messages.
///
/// # Errors
///
/// Returns `Err` when `prompts` is empty, when an API call fails (transient
/// errors are retried up to 5 attempts with 2/4/8/16 second back-off), when
/// `bail_fn` fails, or when the run exceeds `50 * prompts.len()` turns.
pub async fn call_api_with_tools(
    agent: &str,
    prompts: &[String],
    phases: &[String],
    temperature: Option<f32>,
    priority: i32,
    tools: &[agent_tools::Tool],
    bail_fn: Option<&(dyn Fn(usize) -> Result<(), String> + Sync)>,
    log: &dyn Fn(&str),
) -> Result<String, String> {
    // Guard: indexing `prompts[0]` on an empty slice would panic.
    let Some(first_prompt) = prompts.first() else {
        return Err(format!("agent {} has no prompts to run", agent));
    };

    let client = get_client()?;

    let first_phase = phases.first().map(|s| s.as_str()).unwrap_or("");
    // Tracks the "agent:<name>[:<phase>]" label for the current step.
    // It is updated on each step change but not read within this function.
    let _provenance = std::cell::RefCell::new(
        if first_phase.is_empty() { format!("agent:{}", agent) }
        else { format!("agent:{}:{}", agent, first_phase) }
    );

    let mut messages = vec![Message::user(first_prompt)];
    let mut next_prompt_idx = 1;
    let reasoning = crate::config::get().api_reasoning.clone();
    let max_turns = 50 * prompts.len();

    for turn in 0..max_turns {
        log(&format!("\n=== TURN {} ({} messages) ===\n", turn, messages.len()));

        let mut last_err = None;
        let mut msg_opt = None;
        let mut usage_opt = None;
        for attempt in 0..5 {
            let sampling = super::api::SamplingParams {
                temperature: temperature.unwrap_or(0.6),
                top_p: 0.95,
                top_k: 20,
            };
            match client.chat_completion_stream_temp(
                &messages,
                tools,
                &reasoning,
                sampling,
                Some(priority),
            ).await {
                Ok((msg, usage)) => {
                    msg_opt = Some(msg);
                    usage_opt = usage;
                    break;
                }
                Err(e) => {
                    let err_str = e.to_string();
                    let is_transient = err_str.contains("IncompleteMessage")
                        || err_str.contains("connection closed")
                        || err_str.contains("connection reset")
                        || err_str.contains("timed out")
                        || err_str.contains("Connection refused");
                    if is_transient && attempt < 4 {
                        log(&format!("transient error (attempt {}): {}, retrying...",
                            attempt + 1, err_str));
                        // Exponential back-off: 2, 4, 8, 16 seconds.
                        tokio::time::sleep(std::time::Duration::from_secs(2 << attempt)).await;
                        last_err = Some(e);
                        continue;
                    }
                    // Non-transient error, or the final attempt: give up.
                    let msg_bytes: usize = messages.iter()
                        .map(|m| m.content_text().len())
                        .sum();
                    return Err(format!(
                        "API error on turn {} (~{}KB payload, {} messages, {} attempts): {}",
                        turn, msg_bytes / 1024, messages.len(), attempt + 1, e));
                }
            }
        }

        // Every path through the retry loop either stores a message and
        // breaks, or returns early with an error.
        let msg = msg_opt.expect("retry loop sets msg_opt or returns early");
        if let Some(ref e) = last_err {
            log(&format!("succeeded after retry (previous error: {})", e));
        }

        if let Some(u) = &usage_opt {
            log(&format!("tokens: {} prompt + {} completion",
                u.prompt_tokens, u.completion_tokens));
        }

        let has_content = msg.content.is_some();
        let has_tools = msg.tool_calls.as_ref().is_some_and(|tc| !tc.is_empty());

        if has_tools {
            // The copy stored in the conversation gets malformed argument
            // strings replaced with "{}" so the history stays valid JSON.
            let mut sanitized = msg.clone();
            if let Some(ref mut calls) = sanitized.tool_calls {
                for call in calls {
                    if serde_json::from_str::<serde_json::Value>(&call.function.arguments).is_err() {
                        log(&format!("sanitizing malformed args for {}: {}",
                            call.function.name, &call.function.arguments));
                        call.function.arguments = "{}".to_string();
                    }
                }
            }
            messages.push(sanitized);

            // Dispatch from the original (unsanitized) calls so malformed
            // arguments are reported back to the model instead of executed.
            for call in msg.tool_calls.as_ref().expect("has_tools implies tool_calls is Some") {
                log(&format!("\nTOOL CALL: {}({})",
                    call.function.name,
                    &call.function.arguments));

                let args: serde_json::Value = match serde_json::from_str(&call.function.arguments) {
                    Ok(v) => v,
                    Err(_) => {
                        log(&format!("malformed tool call args: {}", &call.function.arguments));
                        messages.push(Message::tool_result(
                            &call.id,
                            "Error: your tool call had malformed JSON arguments. Please retry with valid JSON.",
                        ));
                        continue;
                    }
                };

                let output = agent_tools::dispatch(&call.function.name, &args).await;

                if std::env::var("POC_AGENT_VERBOSE").is_ok() {
                    log(&format!("TOOL RESULT ({} chars):\n{}", output.len(), output));
                } else {
                    // Only the first line, capped at 100 characters.
                    let preview: String = output.lines().next().unwrap_or("").chars().take(100).collect();
                    log(&format!("Result: {}", preview));
                }

                messages.push(Message::tool_result(&call.id, &output));
            }
            continue;
        }

        // Text-only response — step complete
        let text = msg.content_text().to_string();
        if text.is_empty() && !has_content {
            log("empty response, retrying");
            messages.push(Message::user(
                "[system] Your previous response was empty. Please respond with text or use a tool."
            ));
            continue;
        }

        log(&format!("\n=== RESPONSE ===\n\n{}", text));

        // If there are more prompts, check bail condition and inject the next one
        if next_prompt_idx < prompts.len() {
            if let Some(ref check) = bail_fn {
                check(next_prompt_idx)?;
            }
            if let Some(phase) = phases.get(next_prompt_idx) {
                *_provenance.borrow_mut() = format!("agent:{}:{}", agent, phase);
            }
            messages.push(Message::assistant(&text));
            let next = &prompts[next_prompt_idx];
            next_prompt_idx += 1;
            log(&format!("\n=== STEP {}/{} ===\n", next_prompt_idx, prompts.len()));
            messages.push(Message::user(next));
            continue;
        }

        return Ok(text);
    }

    Err(format!("agent exceeded {} tool turns", max_turns))
}
/// Blocking front end for [`call_api_with_tools`].
///
/// Spawns a scoped worker thread, builds a current-thread tokio runtime on
/// it and drives the async call to completion there, so the caller does not
/// need a runtime of its own. Arguments and the returned `Result` are those
/// of the async function; a runtime construction failure is reported as
/// `Err("tokio runtime: …")`. A panic on the worker thread is re-raised in
/// the caller.
pub fn call_api_with_tools_sync(
    agent: &str,
    prompts: &[String],
    phases: &[String],
    temperature: Option<f32>,
    priority: i32,
    tools: &[agent_tools::Tool],
    bail_fn: Option<&(dyn Fn(usize) -> Result<(), String> + Sync)>,
    log: &(dyn Fn(&str) + Sync),
) -> Result<String, String> {
    std::thread::scope(|scope| {
        let worker = scope.spawn(|| {
            let runtime = match tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
            {
                Ok(built) => built,
                Err(e) => return Err(format!("tokio runtime: {}", e)),
            };
            let run = call_api_with_tools(
                agent, prompts, phases, temperature, priority, tools, bail_fn, log,
            );
            runtime.block_on(run)
        });
        worker.join().unwrap()
    })
}
// ---------------------------------------------------------------------------
// Process management — PID tracking and subprocess spawning
// ---------------------------------------------------------------------------
/// Check for live agent processes in a state dir. Returns (phase, pid) pairs.
/// Cleans up stale pid files and kills timed-out processes.
///
/// Files named `pid-<N>` in `state_dir` are inspected. For each one:
/// * if `<N>` is not a strictly positive process id it is skipped;
/// * if signal 0 cannot be delivered to the process, the file is removed;
/// * if `timeout_secs > 0` and the file's mtime is older than that, the
///   process is sent `SIGTERM` and the file is removed;
/// * otherwise the file's trimmed content is reported as the phase.
///
/// An unreadable `state_dir` yields an empty list.
pub fn scan_pid_files(state_dir: &std::path::Path, timeout_secs: u64) -> Vec<(String, u32)> {
    let mut live = Vec::new();
    let Ok(entries) = fs::read_dir(state_dir) else { return live };

    for entry in entries.flatten() {
        let name = entry.file_name();
        let name_str = name.to_string_lossy();
        let Some(pid_str) = name_str.strip_prefix("pid-") else { continue };
        let Ok(pid) = pid_str.parse::<u32>() else { continue };

        // `kill` treats pid 0 as "my process group" and negative pids as
        // "process group" / "every process". A blind `pid as i32` cast would
        // map `pid-0` and values above `i32::MAX` onto those, so reject
        // anything that is not a strictly positive `pid_t`.
        let raw_pid = match libc::pid_t::try_from(pid) {
            Ok(p) if p > 0 => p,
            _ => continue,
        };

        // SAFETY: `kill` takes plain integers and dereferences no memory;
        // signal 0 only performs the existence/permission check.
        if unsafe { libc::kill(raw_pid, 0) } != 0 {
            fs::remove_file(entry.path()).ok();
            continue;
        }

        // A file whose mtime cannot be read, or lies in the future, is
        // treated as not timed out.
        let timed_out = timeout_secs > 0
            && entry.metadata()
                .and_then(|meta| meta.modified())
                .is_ok_and(|modified| {
                    modified.elapsed().unwrap_or_default().as_secs() > timeout_secs
                });
        if timed_out {
            // SAFETY: as above; `raw_pid` is strictly positive, so the
            // signal targets exactly one process.
            unsafe { libc::kill(raw_pid, libc::SIGTERM); }
            fs::remove_file(entry.path()).ok();
            continue;
        }

        let phase = fs::read_to_string(entry.path())
            .unwrap_or_default()
            .trim().to_string();
        live.push((phase, pid));
    }
    live
}
/// Handle returned by `spawn_agent` for a freshly started agent subprocess.
#[derive(Debug)]
pub struct SpawnResult {
    /// The spawned child process.
    pub child: std::process::Child,
    /// Path of the log file chosen for the child's stdout/stderr.
    pub log_path: PathBuf,
}
/// Spawn `poc-memory agent run <agent_name>` as a background subprocess.
///
/// The child's stdout and stderr go to a timestamped log file under
/// `~/.consciousness/logs/<agent_name>/`; if that file cannot be created
/// (or its handle cannot be duplicated) the affected stream is discarded
/// instead. After spawning, a `pid-<pid>` file containing the agent's first
/// phase name is written into `state_dir`.
///
/// * `agent_name` — agent definition to run.
/// * `state_dir` — passed to the child via `--state-dir`; also where the pid
///   file is written.
/// * `session_id` — exported to the child as `POC_SESSION_ID`.
///
/// Returns `None` when the agent definition is unknown or the process could
/// not be spawned.
pub fn spawn_agent(
    agent_name: &str,
    state_dir: &std::path::Path,
    session_id: &str,
) -> Option<SpawnResult> {
    let def = defs::get_def(agent_name)?;
    let first_phase = def.steps.first()
        .map(|s| s.phase.as_str())
        .unwrap_or("step-0");

    let log_dir = dirs::home_dir().unwrap_or_default()
        .join(format!(".consciousness/logs/{}", agent_name));
    // Best effort: a failure here surfaces as a failed log-file create below.
    fs::create_dir_all(&log_dir).ok();
    let log_path = log_dir.join(format!("{}.log", store::compact_timestamp()));

    // Fall back to the null device rather than panicking when the log file
    // (or a second handle to it) is unavailable.
    let (child_stdout, child_stderr) = match fs::File::create(&log_path) {
        Ok(log_file) => {
            let stdout = log_file.try_clone()
                .map(std::process::Stdio::from)
                .unwrap_or_else(|_| std::process::Stdio::null());
            (stdout, std::process::Stdio::from(log_file))
        }
        Err(_) => (std::process::Stdio::null(), std::process::Stdio::null()),
    };

    let child = std::process::Command::new("poc-memory")
        .args(["agent", "run", agent_name, "--count", "1", "--local", "--state-dir"])
        // Pass the path as an OsStr so non-UTF-8 paths survive unchanged.
        .arg(state_dir)
        .env("POC_SESSION_ID", session_id)
        .stdout(child_stdout)
        .stderr(child_stderr)
        .spawn()
        .ok()?;

    let pid = child.id();
    let pid_path = state_dir.join(format!("pid-{}", pid));
    // Best effort: without the pid file the child simply is not tracked.
    fs::write(&pid_path, first_phase).ok();

    Some(SpawnResult { child, log_path })
}