AST-level pattern matching — find code by structure, not text.
e.g. find all `if let Some($X) = $Y { $$$BODY }` patterns.
Supports C, Rust, Python, JS/TS, Go, and 20+ languages.
Gracefully errors if sg binary isn't installed.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
146 lines
5.1 KiB
Rust
146 lines
5.1 KiB
Rust
// tools/ast_grep.rs — Structural code search using the ast-grep library
//
// AST-level pattern matching: find code structures, not just text.
// Uses ast-grep-core and ast-grep-language directly — no shell subprocess.
|
|
|
|
use std::sync::Arc;
|
|
use std::path::Path;
|
|
use anyhow::{Context, Result};
|
|
use serde::Deserialize;
|
|
|
|
use ast_grep_core::Pattern;
|
|
use ast_grep_language::{SupportLang, LanguageExt};
|
|
|
|
#[derive(Deserialize)]
|
|
struct Args {
|
|
pattern: String,
|
|
#[serde(default = "default_path")]
|
|
path: String,
|
|
lang: Option<String>,
|
|
}
|
|
|
|
/// Default search path used when the caller omits `path`: the current
/// directory.
fn default_path() -> String {
    String::from(".")
}
|
|
|
|
pub fn tool() -> super::Tool {
|
|
super::Tool {
|
|
name: "ast_grep",
|
|
description: "Structural code search using AST patterns. Finds code by structure, not text — \
|
|
e.g. find all `if let Some($X) = $Y { $$$BODY }` patterns. \
|
|
Supports C, Rust, Python, JS/TS, Go, Java, and 20+ languages.",
|
|
parameters_json: r#"{"type":"object","properties":{"pattern":{"type":"string","description":"AST pattern to search for. Use $X for single node wildcards, $$$X for multiple nodes."},"path":{"type":"string","description":"Directory or file to search in (default: current directory)"},"lang":{"type":"string","description":"Language (e.g. 'rust', 'c', 'python', 'javascript'). Auto-detected from file extension if omitted."}},"required":["pattern"]}"#,
|
|
handler: Arc::new(|_a, v| Box::pin(async move { ast_grep_search(&v) })),
|
|
}
|
|
}
|
|
|
|
fn detect_lang(path: &Path) -> Option<SupportLang> {
|
|
let ext = path.extension()?.to_str()?;
|
|
parse_lang(ext)
|
|
}
|
|
|
|
fn parse_lang(name: &str) -> Option<SupportLang> {
|
|
// ast-grep-language provides from_extension but we want from name
|
|
match name.to_lowercase().as_str() {
|
|
"rust" | "rs" => Some(SupportLang::Rust),
|
|
"c" => Some(SupportLang::C),
|
|
"cpp" | "c++" | "cc" | "cxx" => Some(SupportLang::Cpp),
|
|
"python" | "py" => Some(SupportLang::Python),
|
|
"javascript" | "js" => Some(SupportLang::JavaScript),
|
|
"typescript" | "ts" => Some(SupportLang::TypeScript),
|
|
"go" => Some(SupportLang::Go),
|
|
"java" => Some(SupportLang::Java),
|
|
"json" => Some(SupportLang::Json),
|
|
"html" => Some(SupportLang::Html),
|
|
"css" => Some(SupportLang::Css),
|
|
"bash" | "sh" => Some(SupportLang::Bash),
|
|
"ruby" | "rb" => Some(SupportLang::Ruby),
|
|
"yaml" | "yml" => Some(SupportLang::Yaml),
|
|
"lua" => Some(SupportLang::Lua),
|
|
"kotlin" | "kt" => Some(SupportLang::Kotlin),
|
|
"swift" => Some(SupportLang::Swift),
|
|
"scala" => Some(SupportLang::Scala),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn search_file(
|
|
path: &Path,
|
|
lang: SupportLang,
|
|
pattern: &Pattern,
|
|
results: &mut Vec<String>,
|
|
) -> Result<()> {
|
|
let source = std::fs::read_to_string(path)
|
|
.with_context(|| format!("reading {}", path.display()))?;
|
|
let tree = lang.ast_grep(&source);
|
|
for node_match in tree.root().find_all(pattern) {
|
|
let start = node_match.start_pos();
|
|
let line = start.line() + 1;
|
|
let matched_text = node_match.text();
|
|
let preview = if matched_text.len() > 200 {
|
|
format!("{}...", &matched_text[..200])
|
|
} else {
|
|
matched_text.to_string()
|
|
};
|
|
results.push(format!("{}:{}: {}", path.display(), line, preview));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn walk_and_search(
|
|
dir: &Path,
|
|
explicit_lang: Option<SupportLang>,
|
|
pattern_str: &str,
|
|
results: &mut Vec<String>,
|
|
) -> Result<()> {
|
|
if dir.is_file() {
|
|
let lang = explicit_lang
|
|
.or_else(|| detect_lang(dir))
|
|
.ok_or_else(|| anyhow::anyhow!("cannot detect language for {}", dir.display()))?;
|
|
let pattern = Pattern::new(pattern_str, lang);
|
|
return search_file(dir, lang, &pattern, results);
|
|
}
|
|
|
|
for entry in walkdir::WalkDir::new(dir)
|
|
.into_iter()
|
|
.filter_entry(|e| {
|
|
let name = e.file_name().to_str().unwrap_or("");
|
|
!name.starts_with('.') && name != "target" && name != "node_modules"
|
|
})
|
|
{
|
|
let entry = match entry {
|
|
Ok(e) => e,
|
|
Err(_) => continue,
|
|
};
|
|
if !entry.file_type().is_file() { continue; }
|
|
|
|
let path = entry.path();
|
|
let lang = match explicit_lang.or_else(|| detect_lang(path)) {
|
|
Some(l) => l,
|
|
None => continue,
|
|
};
|
|
let pattern = Pattern::new(pattern_str, lang);
|
|
let _ = search_file(path, lang, &pattern, results);
|
|
|
|
if results.len() >= 100 {
|
|
results.push("... (truncated at 100 matches)".into());
|
|
break;
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn ast_grep_search(args: &serde_json::Value) -> Result<String> {
|
|
let a: Args = serde_json::from_value(args.clone())
|
|
.context("invalid ast_grep arguments")?;
|
|
|
|
let explicit_lang = a.lang.as_deref().and_then(parse_lang);
|
|
let path = Path::new(&a.path);
|
|
|
|
let mut results = Vec::new();
|
|
walk_and_search(path, explicit_lang, &a.pattern, &mut results)?;
|
|
|
|
if results.is_empty() {
|
|
return Ok("No matches found.".to_string());
|
|
}
|
|
|
|
Ok(super::truncate_output(results.join("\n"), 30000))
|
|
}
|