From 74f05924ff47c1ce9cc34b5b8dfd37d296cb4d07 Mon Sep 17 00:00:00 2001 From: Kent Overstreet Date: Sat, 21 Mar 2026 16:28:10 -0400 Subject: [PATCH] refactor: use typed Deserialize structs for tool arguments Convert read_file, write_file, edit_file, and glob from manual args["key"].as_str() parsing to serde_json::from_value with typed Args structs. Gives type safety, default values via serde attributes, and clearer error messages on missing/wrong-type arguments. Co-Authored-By: Claude Opus 4.6 (1M context) --- poc-agent/src/tools/edit.rs | 62 ++++++++++++++++---------------- poc-agent/src/tools/glob_tool.rs | 21 +++++++---- poc-agent/src/tools/read.rs | 27 +++++++++----- poc-agent/src/tools/write.rs | 24 +++++++------ 4 files changed, 77 insertions(+), 57 deletions(-) diff --git a/poc-agent/src/tools/edit.rs b/poc-agent/src/tools/edit.rs index 15f0f9e..d1db659 100644 --- a/poc-agent/src/tools/edit.rs +++ b/poc-agent/src/tools/edit.rs @@ -8,10 +8,20 @@ // Supports replace_all for bulk renaming (e.g. variable renames). use anyhow::{Context, Result}; +use serde::Deserialize; use serde_json::json; use crate::types::ToolDef; +#[derive(Deserialize)] +struct Args { + file_path: String, + old_string: String, + new_string: String, + #[serde(default)] + replace_all: bool, +} + pub fn definition() -> ToolDef { ToolDef::new( "edit_file", @@ -44,49 +54,37 @@ pub fn definition() -> ToolDef { } pub fn edit_file(args: &serde_json::Value) -> Result { - let path = args["file_path"] - .as_str() - .context("file_path is required")?; - let old_string = args["old_string"] - .as_str() - .context("old_string is required")?; - let new_string = args["new_string"] - .as_str() - .context("new_string is required")?; - let replace_all = args["replace_all"].as_bool().unwrap_or(false); + let a: Args = serde_json::from_value(args.clone()) + .context("invalid edit_file arguments")?; - if old_string == new_string { + if a.old_string == a.new_string { anyhow::bail!("old_string and new_string are identical"); } - let content = - std::fs::read_to_string(path).with_context(|| format!("Failed to read {}", path))?; + let content = std::fs::read_to_string(&a.file_path) + .with_context(|| format!("Failed to read {}", a.file_path))?; - if replace_all { - let count = content.matches(old_string).count(); - if count == 0 { - anyhow::bail!("old_string not found in {}", path); - } - let new_content = content.replace(old_string, new_string); - std::fs::write(path, &new_content) - .with_context(|| format!("Failed to write {}", path))?; - Ok(format!("Replaced {} occurrences in {}", count, path)) + let count = content.matches(&*a.old_string).count(); + if count == 0 { + anyhow::bail!("old_string not found in {}", a.file_path); + } + + if a.replace_all { + let new_content = content.replace(&*a.old_string, &a.new_string); + std::fs::write(&a.file_path, &new_content) + .with_context(|| format!("Failed to write {}", a.file_path))?; + Ok(format!("Replaced {} occurrences in {}", count, a.file_path)) } else { - let count = content.matches(old_string).count(); - if count == 0 { - anyhow::bail!("old_string not found in {}", path); - } if count > 1 { anyhow::bail!( "old_string appears {} times in {} — use replace_all or provide more context \ to make it unique", - count, - path + count, a.file_path ); } - let new_content = content.replacen(old_string, new_string, 1); - std::fs::write(path, &new_content) - .with_context(|| format!("Failed to write {}", path))?; - Ok(format!("Edited {}", path)) + let new_content = content.replacen(&*a.old_string, &a.new_string, 1); + std::fs::write(&a.file_path, &new_content) + .with_context(|| format!("Failed to write {}", a.file_path))?; + Ok(format!("Edited {}", a.file_path)) } } diff --git a/poc-agent/src/tools/glob_tool.rs b/poc-agent/src/tools/glob_tool.rs index 32ccb6f..5ab1503 100644 --- a/poc-agent/src/tools/glob_tool.rs +++ b/poc-agent/src/tools/glob_tool.rs @@ -5,11 +5,21 @@ // what you want when exploring a codebase. use anyhow::{Context, Result}; +use serde::Deserialize; use serde_json::json; use std::path::PathBuf; use crate::types::ToolDef; +#[derive(Deserialize)] +struct Args { + pattern: String, + #[serde(default = "default_path")] + path: String, +} + +fn default_path() -> String { ".".into() } + pub fn definition() -> ToolDef { ToolDef::new( "glob", @@ -34,14 +44,13 @@ pub fn definition() -> ToolDef { } pub fn glob_search(args: &serde_json::Value) -> Result { - let pattern = args["pattern"].as_str().context("pattern is required")?; - let base = args["path"].as_str().unwrap_or("."); + let a: Args = serde_json::from_value(args.clone()) + .context("invalid glob arguments")?; - // Build the full pattern - let full_pattern = if pattern.starts_with('/') { - pattern.to_string() + let full_pattern = if a.pattern.starts_with('/') { + a.pattern.clone() } else { - format!("{}/{}", base, pattern) + format!("{}/{}", a.path, a.pattern) }; let mut entries: Vec<(PathBuf, std::time::SystemTime)> = Vec::new(); diff --git a/poc-agent/src/tools/read.rs b/poc-agent/src/tools/read.rs index 57c9418..d454c95 100644 --- a/poc-agent/src/tools/read.rs +++ b/poc-agent/src/tools/read.rs @@ -1,10 +1,21 @@ // tools/read.rs — Read file contents use anyhow::{Context, Result}; +use serde::Deserialize; use serde_json::json; use crate::types::ToolDef; +#[derive(Deserialize)] +struct Args { + file_path: String, + #[serde(default = "default_offset")] + offset: usize, + limit: Option, +} + +fn default_offset() -> usize { 1 } + pub fn definition() -> ToolDef { ToolDef::new( "read_file", @@ -31,21 +42,19 @@ pub fn definition() -> ToolDef { } pub fn read_file(args: &serde_json::Value) -> Result { - let path = args["file_path"] - .as_str() - .context("file_path is required")?; + let args: Args = serde_json::from_value(args.clone()) + .context("invalid read_file arguments")?; - let content = - std::fs::read_to_string(path).with_context(|| format!("Failed to read {}", path))?; + let content = std::fs::read_to_string(&args.file_path) + .with_context(|| format!("Failed to read {}", args.file_path))?; let lines: Vec<&str> = content.lines().collect(); - let offset = args["offset"].as_u64().unwrap_or(1).max(1) as usize - 1; - let limit = args["limit"].as_u64().unwrap_or(lines.len() as u64) as usize; + let offset = args.offset.max(1) - 1; + let limit = args.limit.unwrap_or(lines.len()); let mut output = String::new(); for (i, line) in lines.iter().skip(offset).take(limit).enumerate() { - let line_num = offset + i + 1; - output.push_str(&format!("{:>6}\t{}\n", line_num, line)); + output.push_str(&format!("{:>6}\t{}\n", offset + i + 1, line)); } if output.is_empty() { diff --git a/poc-agent/src/tools/write.rs b/poc-agent/src/tools/write.rs index 06135b3..b244b05 100644 --- a/poc-agent/src/tools/write.rs +++ b/poc-agent/src/tools/write.rs @@ -1,11 +1,18 @@ // tools/write.rs — Write file contents use anyhow::{Context, Result}; +use serde::Deserialize; use serde_json::json; use std::path::Path; use crate::types::ToolDef; +#[derive(Deserialize)] +struct Args { + file_path: String, + content: String, +} + pub fn definition() -> ToolDef { ToolDef::new( "write_file", @@ -29,19 +36,16 @@ pub fn definition() -> ToolDef { } pub fn write_file(args: &serde_json::Value) -> Result { - let path = args["file_path"] - .as_str() - .context("file_path is required")?; - let content = args["content"].as_str().context("content is required")?; + let args: Args = serde_json::from_value(args.clone()) + .context("invalid write_file arguments")?; - // Create parent directories if needed - if let Some(parent) = Path::new(path).parent() { + if let Some(parent) = Path::new(&args.file_path).parent() { std::fs::create_dir_all(parent) - .with_context(|| format!("Failed to create directories for {}", path))?; + .with_context(|| format!("Failed to create directories for {}", args.file_path))?; } - std::fs::write(path, content).with_context(|| format!("Failed to write {}", path))?; + std::fs::write(&args.file_path, &args.content) + .with_context(|| format!("Failed to write {}", args.file_path))?; - let line_count = content.lines().count(); - Ok(format!("Wrote {} lines to {}", line_count, path)) + Ok(format!("Wrote {} lines to {}", args.content.lines().count(), args.file_path)) }