AST-level pattern matching — find code by structure, not text.
e.g. find all `if let Some($X) = $Y { $$$BODY }` patterns.
Supports C, Rust, Python, JS/TS, Go, and 20+ languages.
Returns a graceful error if the `sg` binary isn't installed.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
236 lines
6.5 KiB
Rust
236 lines
6.5 KiB
Rust
// tools/mod.rs — Agent-specific tool dispatch
|
|
//
|
|
// Shared tools (memory, files, bash, journal) live in thought/.
|
|
// This module handles agent-specific tools (control, vision,
|
|
// working_stack) and delegates everything else to thought::dispatch.
|
|
|
|
// Core tools
|
|
mod ast_grep;
|
|
mod bash;
|
|
pub mod channels;
|
|
mod edit;
|
|
mod glob;
|
|
mod grep;
|
|
pub mod memory;
|
|
mod read;
|
|
mod web;
|
|
mod write;
|
|
|
|
// Agent-specific tools
|
|
mod control;
|
|
mod vision;
|
|
|
|
use std::future::Future;
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
|
|
/// Default timeout value used when none is supplied.
///
/// NOTE(review): the unit is not stated in this file — presumably seconds;
/// confirm at the use site.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT: u64 = 120;
    DEFAULT_TIMEOUT
}
|
|
|
|
pub type ToolHandler = Arc<dyn Fn(
|
|
Option<std::sync::Arc<super::Agent>>,
|
|
serde_json::Value,
|
|
) -> Pin<Box<dyn Future<Output = anyhow::Result<String>> + Send>>
|
|
+ Send + Sync>;
|
|
|
|
#[derive(Clone)]
|
|
pub struct Tool {
|
|
pub name: &'static str,
|
|
pub description: &'static str,
|
|
pub parameters_json: &'static str,
|
|
pub handler: ToolHandler,
|
|
}
|
|
|
|
impl Tool {
|
|
/// Build the JSON for this tool's definition (for the API tools array).
|
|
pub fn to_json(&self) -> String {
|
|
format!(
|
|
r#"{{"type":"function","function":{{"name":"{}","description":"{}","parameters":{}}}}}"#,
|
|
self.name,
|
|
self.description.replace('"', r#"\""#),
|
|
self.parameters_json,
|
|
)
|
|
}
|
|
}
|
|
|
|
pub struct ActiveToolCall {
|
|
pub id: String,
|
|
pub name: String,
|
|
pub detail: String,
|
|
pub started: Instant,
|
|
pub background: bool,
|
|
pub handle: tokio::task::JoinHandle<(super::context::PendingToolCall, String)>,
|
|
}
|
|
|
|
pub struct ActiveTools(Vec<ActiveToolCall>);
|
|
|
|
impl ActiveTools {
|
|
pub fn new() -> Self { Self(Vec::new()) }
|
|
|
|
pub fn push(&mut self, call: ActiveToolCall) {
|
|
self.0.push(call);
|
|
}
|
|
|
|
pub fn remove(&mut self, id: &str) {
|
|
self.0.retain(|t| t.id != id);
|
|
}
|
|
|
|
pub fn take_finished(&mut self) -> Vec<ActiveToolCall> {
|
|
let mut finished = Vec::new();
|
|
let mut i = 0;
|
|
while i < self.0.len() {
|
|
if self.0[i].handle.is_finished() {
|
|
finished.push(self.0.remove(i));
|
|
} else {
|
|
i += 1;
|
|
}
|
|
}
|
|
finished
|
|
}
|
|
|
|
pub fn take_foreground(&mut self) -> Vec<ActiveToolCall> {
|
|
let mut fg = Vec::new();
|
|
let mut i = 0;
|
|
while i < self.0.len() {
|
|
if !self.0[i].background {
|
|
fg.push(self.0.remove(i));
|
|
} else {
|
|
i += 1;
|
|
}
|
|
}
|
|
fg
|
|
}
|
|
|
|
pub fn iter(&self) -> impl Iterator<Item = &ActiveToolCall> {
|
|
self.0.iter()
|
|
}
|
|
|
|
pub fn abort_all(&mut self) {
|
|
for entry in self.0.drain(..) {
|
|
entry.handle.abort();
|
|
}
|
|
}
|
|
|
|
pub fn len(&self) -> usize { self.0.len() }
|
|
pub fn is_empty(&self) -> bool { self.0.is_empty() }
|
|
}
|
|
|
|
/// Truncate output if it exceeds `max` bytes, appending a truncation notice.
///
/// `max` is a byte budget for the retained text (the notice is added on top
/// of it). When `max` falls inside a multi-byte UTF-8 character the cut is
/// moved back to the previous character boundary, so the function never
/// panics on non-ASCII output. Strings of at most `max` bytes are returned
/// unchanged.
pub fn truncate_output(mut s: String, max: usize) -> String {
    if s.len() > max {
        // `String::truncate` panics off a char boundary; back off to one.
        // Index 0 is always a boundary, so this loop terminates.
        let mut cut = max;
        while !s.is_char_boundary(cut) {
            cut -= 1;
        }
        s.truncate(cut);
        s.push_str("\n... (output truncated)");
    }
    s
}
|
|
|
|
/// Dispatch a tool call by name through the registry.
|
|
/// Dispatch a tool call by name. Returns the result text,
|
|
/// or an error string prefixed with "Error: ".
|
|
pub async fn dispatch(
|
|
name: &str,
|
|
args: &serde_json::Value,
|
|
) -> String {
|
|
dispatch_with_agent(name, args, None).await
|
|
}
|
|
|
|
/// Dispatch a tool call with optional agent context.
|
|
/// If agent is provided, uses the agent's tool list.
|
|
pub async fn dispatch_with_agent(
|
|
name: &str,
|
|
args: &serde_json::Value,
|
|
agent: Option<std::sync::Arc<super::Agent>>,
|
|
) -> String {
|
|
let tool = if let Some(ref a) = agent {
|
|
// Only dispatch tools the agent is allowed to use
|
|
let guard = a.state.lock().await;
|
|
guard.tools.iter().find(|t| t.name == name).cloned()
|
|
} else {
|
|
// No agent context — allow all tools (CLI/MCP path)
|
|
tools().into_iter().find(|t| t.name == name)
|
|
};
|
|
match tool {
|
|
Some(t) => (t.handler)(agent, args.clone()).await
|
|
.unwrap_or_else(|e| format!("Error: {}", e)),
|
|
None => format!("Error: Unknown tool: {}", name),
|
|
}
|
|
}
|
|
|
|
/// Return all registered tools with definitions + handlers.
|
|
pub fn tools() -> Vec<Tool> {
|
|
let mut all = vec![
|
|
read::tool(), write::tool(), edit::tool(),
|
|
grep::tool(), glob::tool(), bash::tool(),
|
|
ast_grep::tool(), vision::tool(),
|
|
];
|
|
all.extend(web::tools());
|
|
all.extend(memory::memory_tools());
|
|
all.extend(memory::journal_tools());
|
|
all.extend(channels::tools());
|
|
all.extend(control::tools());
|
|
all
|
|
}
|
|
|
|
/// Memory + journal tools only — for subconscious agents.
|
|
pub fn memory_and_journal_tools() -> Vec<Tool> {
|
|
let mut all = memory::memory_tools().to_vec();
|
|
all.extend(memory::journal_tools());
|
|
all
|
|
}
|
|
|
|
/// Create a short summary of tool args for the tools pane header.
|
|
pub fn summarize_args(tool_name: &str, args: &serde_json::Value) -> String {
|
|
match tool_name {
|
|
"read_file" | "write_file" | "edit_file" => args["file_path"]
|
|
.as_str()
|
|
.unwrap_or("")
|
|
.to_string(),
|
|
"bash" => {
|
|
let cmd = args["command"].as_str().unwrap_or("");
|
|
if cmd.len() > 60 {
|
|
let end = cmd.char_indices()
|
|
.map(|(i, _)| i)
|
|
.take_while(|&i| i <= 60)
|
|
.last()
|
|
.unwrap_or(0);
|
|
format!("{}...", &cmd[..end])
|
|
} else {
|
|
cmd.to_string()
|
|
}
|
|
}
|
|
"grep" => {
|
|
let pattern = args["pattern"].as_str().unwrap_or("");
|
|
let path = args["path"].as_str().unwrap_or(".");
|
|
format!("{} in {}", pattern, path)
|
|
}
|
|
"glob" => args["pattern"]
|
|
.as_str()
|
|
.unwrap_or("")
|
|
.to_string(),
|
|
"view_image" => {
|
|
if let Some(pane) = args["pane_id"].as_str() {
|
|
format!("pane {}", pane)
|
|
} else {
|
|
args["file_path"].as_str().unwrap_or("").to_string()
|
|
}
|
|
}
|
|
"journal" => {
|
|
let entry = args["entry"].as_str().unwrap_or("");
|
|
if entry.len() > 60 {
|
|
format!("{}...", &entry[..60])
|
|
} else {
|
|
entry.to_string()
|
|
}
|
|
}
|
|
"yield_to_user" => args["message"]
|
|
.as_str()
|
|
.unwrap_or("")
|
|
.to_string(),
|
|
"switch_model" => args["model"]
|
|
.as_str()
|
|
.unwrap_or("")
|
|
.to_string(),
|
|
"pause" => String::new(),
|
|
_ => String::new(),
|
|
}
|
|
}
|