Add ast_grep tool: structural code search via ast-grep

AST-level pattern matching — find code by structure, not text.
e.g. find all `if let Some($X) = $Y { $$$BODY }` patterns.
Supports C, Rust, Python, JS/TS, Go, and 20+ languages.

Uses the ast-grep crates in-process, so no external sg binary needs to be installed.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
ProofOfConcept 2026-04-09 02:57:02 -04:00 committed by Kent Overstreet
parent c53c4f9071
commit ec7e11db56
5 changed files with 552 additions and 5 deletions

146
src/agent/tools/ast_grep.rs Normal file
View file

@ -0,0 +1,146 @@
// tools/ast_grep.rs — Structural code search using ast-grep library
//
// AST-level pattern matching: find code structures, not just text.
// Uses ast-grep-core and ast-grep-language directly — no shell subprocess.
use std::sync::Arc;
use std::path::Path;
use anyhow::{Context, Result};
use serde::Deserialize;
use ast_grep_core::Pattern;
use ast_grep_language::{SupportLang, LanguageExt};
/// Arguments accepted by the `ast_grep` tool, deserialized from the call's JSON payload.
#[derive(Deserialize)]
struct Args {
/// AST pattern to search for (the only required field).
pattern: String,
/// File or directory to search; filled in by `default_path()` when omitted.
#[serde(default = "default_path")]
path: String,
/// Optional language name; when absent the language is detected from each file's extension.
lang: Option<String>,
}
/// Serde default for `Args::path`: the current directory.
fn default_path() -> String {
    String::from(".")
}
/// Builds the `ast_grep` tool descriptor: its name, the description shown to the
/// caller, the JSON schema of its parameters, and the handler that runs the search.
pub fn tool() -> super::Tool {
super::Tool {
name: "ast_grep",
description: "Structural code search using AST patterns. Finds code by structure, not text — \
e.g. find all `if let Some($X) = $Y { $$$BODY }` patterns. \
Supports C, Rust, Python, JS/TS, Go, Java, and 20+ languages.",
// JSON schema: `pattern` is required; `path` and `lang` are optional.
parameters_json: r#"{"type":"object","properties":{"pattern":{"type":"string","description":"AST pattern to search for. Use $X for single node wildcards, $$$X for multiple nodes."},"path":{"type":"string","description":"Directory or file to search in (default: current directory)"},"lang":{"type":"string","description":"Language (e.g. 'rust', 'c', 'python', 'javascript'). Auto-detected from file extension if omitted."}},"required":["pattern"]}"#,
// The first handler argument is unused; the search itself is synchronous and
// is simply wrapped in an async block so it yields a boxed future.
handler: Arc::new(|_a, v| Box::pin(async move { ast_grep_search(&v) })),
}
}
fn detect_lang(path: &Path) -> Option<SupportLang> {
let ext = path.extension()?.to_str()?;
parse_lang(ext)
}
/// Maps a language name or file extension (case-insensitive) to a `SupportLang`.
///
/// This one table serves two callers: the user-supplied `lang` argument and
/// extension-based detection in `detect_lang`. It therefore lists common
/// secondary extensions (C/C++ headers, `.jsx`, `.mjs`, `.tsx`, `.pyi`, ...)
/// next to the language names, so those files are not skipped when a
/// directory is searched.
///
/// Returns `None` for anything not listed.
fn parse_lang(name: &str) -> Option<SupportLang> {
    let lang = match name.to_lowercase().as_str() {
        "rust" | "rs" => SupportLang::Rust,
        // NOTE(review): `.h` is treated as C, so C++ headers named `.h` are parsed as C.
        "c" | "h" => SupportLang::C,
        "cpp" | "c++" | "cc" | "cxx" | "hpp" | "hh" | "hxx" => SupportLang::Cpp,
        "python" | "py" | "pyi" => SupportLang::Python,
        "javascript" | "js" | "jsx" | "mjs" | "cjs" => SupportLang::JavaScript,
        "typescript" | "ts" | "mts" | "cts" => SupportLang::TypeScript,
        "tsx" => SupportLang::Tsx,
        "go" | "golang" => SupportLang::Go,
        "java" => SupportLang::Java,
        "json" => SupportLang::Json,
        "html" | "htm" => SupportLang::Html,
        "css" => SupportLang::Css,
        "bash" | "sh" => SupportLang::Bash,
        "ruby" | "rb" => SupportLang::Ruby,
        "yaml" | "yml" => SupportLang::Yaml,
        "lua" => SupportLang::Lua,
        "kotlin" | "kt" | "kts" => SupportLang::Kotlin,
        "swift" => SupportLang::Swift,
        "scala" => SupportLang::Scala,
        _ => return None,
    };
    Some(lang)
}
/// Searches a single file for `pattern` and appends one `path:line: preview`
/// entry to `results` for every match.
///
/// * `path` - file to read; it must be valid UTF-8.
/// * `lang` - language used to parse the file.
/// * `pattern` - compiled pattern to match against the file's syntax tree.
/// * `results` - output list the formatted matches are pushed onto.
///
/// # Errors
///
/// Returns an error when the file cannot be read as a UTF-8 string.
fn search_file(
    path: &Path,
    lang: SupportLang,
    pattern: &Pattern,
    results: &mut Vec<String>,
) -> Result<()> {
    /// Longest preview, in bytes, shown before the matched text is elided.
    const MAX_PREVIEW_BYTES: usize = 200;

    let source = std::fs::read_to_string(path)
        .with_context(|| format!("reading {}", path.display()))?;
    let tree = lang.ast_grep(&source);

    for node_match in tree.root().find_all(pattern) {
        // Reported line numbers are shifted by one (assumes `line()` is zero-based).
        let line = node_match.start_pos().line() + 1;
        let matched_text = node_match.text();

        let preview = if matched_text.len() > MAX_PREVIEW_BYTES {
            // Slicing at a fixed byte offset panics when that offset falls
            // inside a multi-byte character, so back up to the nearest
            // boundary. Offset 0 is always a boundary, so this terminates.
            let mut cut = MAX_PREVIEW_BYTES;
            while !matched_text.is_char_boundary(cut) {
                cut -= 1;
            }
            format!("{}...", &matched_text[..cut])
        } else {
            matched_text.to_string()
        };

        results.push(format!("{}:{}: {}", path.display(), line, preview));
    }
    Ok(())
}
fn walk_and_search(
dir: &Path,
explicit_lang: Option<SupportLang>,
pattern_str: &str,
results: &mut Vec<String>,
) -> Result<()> {
if dir.is_file() {
let lang = explicit_lang
.or_else(|| detect_lang(dir))
.ok_or_else(|| anyhow::anyhow!("cannot detect language for {}", dir.display()))?;
let pattern = Pattern::new(pattern_str, lang);
return search_file(dir, lang, &pattern, results);
}
for entry in walkdir::WalkDir::new(dir)
.into_iter()
.filter_entry(|e| {
let name = e.file_name().to_str().unwrap_or("");
!name.starts_with('.') && name != "target" && name != "node_modules"
})
{
let entry = match entry {
Ok(e) => e,
Err(_) => continue,
};
if !entry.file_type().is_file() { continue; }
let path = entry.path();
let lang = match explicit_lang.or_else(|| detect_lang(path)) {
Some(l) => l,
None => continue,
};
let pattern = Pattern::new(pattern_str, lang);
let _ = search_file(path, lang, &pattern, results);
if results.len() >= 100 {
results.push("... (truncated at 100 matches)".into());
break;
}
}
Ok(())
}
fn ast_grep_search(args: &serde_json::Value) -> Result<String> {
let a: Args = serde_json::from_value(args.clone())
.context("invalid ast_grep arguments")?;
let explicit_lang = a.lang.as_deref().and_then(parse_lang);
let path = Path::new(&a.path);
let mut results = Vec::new();
walk_and_search(path, explicit_lang, &a.pattern, &mut results)?;
if results.is_empty() {
return Ok("No matches found.".to_string());
}
Ok(super::truncate_output(results.join("\n"), 30000))
}

View file

@ -5,6 +5,7 @@
// working_stack) and delegates everything else to thought::dispatch.
// Core tools
mod ast_grep;
mod bash;
pub mod channels;
mod edit;
@ -160,7 +161,7 @@ pub fn tools() -> Vec<Tool> {
let mut all = vec![
read::tool(), write::tool(), edit::tool(),
grep::tool(), glob::tool(), bash::tool(),
vision::tool(),
ast_grep::tool(), vision::tool(),
];
all.extend(web::tools());
all.extend(memory::memory_tools());