diff --git a/src/agent/mod.rs b/src/agent/mod.rs index e79a71b..8b6f43d 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -138,8 +138,17 @@ pub struct Agent { } /// Mutable agent state — behind its own mutex. +/// Which external MCP tools an agent can access. +#[derive(Clone)] +pub enum McpToolAccess { + None, + All, + Some(Vec), +} + pub struct AgentState { pub tools: Vec, + pub mcp_tools: McpToolAccess, pub last_prompt_tokens: u32, pub reasoning_effort: String, pub temperature: f32, @@ -174,8 +183,7 @@ impl Agent { context.conversation_log = conversation_log; context.push_no_log(Section::System, AstNode::system_msg(&system_prompt)); - let tool_defs: Vec = tools::tools().iter() - .map(|t| t.to_json()).collect(); + let tool_defs = tools::all_tool_definitions().await; if !tool_defs.is_empty() { let tools_text = format!( "# Tools\n\nYou have access to the following functions:\n\n\n{}\n\n\n\ @@ -202,6 +210,7 @@ impl Agent { context: tokio::sync::Mutex::new(context), state: tokio::sync::Mutex::new(AgentState { tools: tools::tools(), + mcp_tools: McpToolAccess::All, last_prompt_tokens: 0, reasoning_effort: "none".to_string(), temperature: 0.6, @@ -237,6 +246,7 @@ impl Agent { context: tokio::sync::Mutex::new(ctx), state: tokio::sync::Mutex::new(AgentState { tools, + mcp_tools: McpToolAccess::None, last_prompt_tokens: 0, reasoning_effort: "none".to_string(), temperature: st.temperature, diff --git a/src/agent/tools/mcp_client.rs b/src/agent/tools/mcp_client.rs new file mode 100644 index 0000000..a7348ec --- /dev/null +++ b/src/agent/tools/mcp_client.rs @@ -0,0 +1,192 @@ +// tools/mcp_client.rs — MCP client for external tool servers +// +// Spawns external MCP servers, discovers their tools, dispatches calls. +// JSON-RPC 2.0 over stdio (newline-delimited). Global registry, lazy +// init from config. + +use anyhow::{Context, Result}; +use serde::Deserialize; +use serde_json::json; +use std::sync::OnceLock; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}; +use tokio::process::{Child, ChildStdin, ChildStdout, Command}; +use tokio::sync::Mutex as TokioMutex; + +#[derive(Debug, Clone)] +pub struct McpTool { + pub name: String, + pub description: String, + pub parameters_json: String, +} + +struct McpServer { + #[allow(dead_code)] + name: String, + stdin: BufWriter, + stdout: BufReader, + _child: Child, + next_id: u64, + tools: Vec, +} + +#[derive(Debug, Deserialize)] +struct JsonRpcResponse { + id: Option, + result: Option, + error: Option, +} + +#[derive(Debug, Deserialize)] +struct JsonRpcError { + code: i64, + message: String, +} + +impl McpServer { + async fn request(&mut self, method: &str, params: Option) -> Result { + self.next_id += 1; + let id = self.next_id; + let req = json!({ "jsonrpc": "2.0", "id": id, "method": method, "params": params }); + let mut line = serde_json::to_string(&req)?; + line.push('\n'); + self.stdin.write_all(line.as_bytes()).await?; + self.stdin.flush().await?; + + let mut buf = String::new(); + loop { + buf.clear(); + let n = self.stdout.read_line(&mut buf).await?; + if n == 0 { anyhow::bail!("MCP server closed connection"); } + let trimmed = buf.trim(); + if trimmed.is_empty() { continue; } + if let Ok(resp) = serde_json::from_str::(trimmed) { + if resp.id == Some(id) { + if let Some(err) = resp.error { + anyhow::bail!("MCP error {}: {}", err.code, err.message); + } + return Ok(resp.result.unwrap_or(serde_json::Value::Null)); + } + } + } + } + + async fn notify(&mut self, method: &str) -> Result<()> { + let msg = json!({ "jsonrpc": "2.0", "method": method }); + let mut line = serde_json::to_string(&msg)?; + line.push('\n'); + self.stdin.write_all(line.as_bytes()).await?; + self.stdin.flush().await?; + Ok(()) + } + + async fn spawn(name: &str, command: &str, args: &[&str]) -> Result { + let mut child = Command::new(command) + .args(args) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::null()) + .spawn() + .with_context(|| format!("spawning MCP server: {} {}", command, args.join(" ")))?; + + let mut server = McpServer { + name: name.to_string(), + stdin: BufWriter::new(child.stdin.take().unwrap()), + stdout: BufReader::new(child.stdout.take().unwrap()), + _child: child, + next_id: 0, + tools: Vec::new(), + }; + + server.request("initialize", Some(json!({ + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "consciousness", "version": "0.1"} + }))).await.with_context(|| format!("initializing MCP server {}", name))?; + + server.notify("notifications/initialized").await?; + + let tools_result = server.request("tools/list", None).await + .with_context(|| format!("listing tools from MCP server {}", name))?; + + if let Some(tool_list) = tools_result.get("tools").and_then(|t| t.as_array()) { + for tool in tool_list { + server.tools.push(McpTool { + name: tool["name"].as_str().unwrap_or("").to_string(), + description: tool["description"].as_str().unwrap_or("").to_string(), + parameters_json: tool.get("inputSchema") + .map(|s| serde_json::to_string(s).unwrap_or_default()) + .unwrap_or_else(|| r#"{"type":"object"}"#.to_string()), + }); + } + } + + dbglog!("[mcp] {} connected: {} tools", name, server.tools.len()); + Ok(server) + } +} + +struct Registry { + servers: Vec, +} + +static REGISTRY: OnceLock> = OnceLock::new(); + +fn registry() -> &'static TokioMutex { + REGISTRY.get_or_init(|| { + let configs = &crate::config::get().mcp_servers; + // Can't do async init in OnceLock, so servers are spawned lazily on first access + let _ = configs; // configs read but servers spawned in ensure_init() + TokioMutex::new(Registry { servers: Vec::new() }) + }) +} + +async fn ensure_init() -> Result<()> { + let mut reg = registry().lock().await; + if !reg.servers.is_empty() { return Ok(()); } + let configs = crate::config::get().mcp_servers.clone(); + for cfg in &configs { + let args: Vec<&str> = cfg.args.iter().map(|s| s.as_str()).collect(); + match McpServer::spawn(&cfg.name, &cfg.command, &args).await { + Ok(server) => reg.servers.push(server), + Err(e) => eprintln!("warning: MCP server {} failed: {:#}", cfg.name, e), + } + } + Ok(()) +} + +pub(super) async fn call_tool(name: &str, args: &serde_json::Value) -> Result { + ensure_init().await?; + let mut reg = registry().lock().await; + let server = reg.servers.iter_mut() + .find(|s| s.tools.iter().any(|t| t.name == name)) + .ok_or_else(|| anyhow::anyhow!("no MCP server has tool {}", name))?; + + let result = server.request("tools/call", Some(json!({ + "name": name, "arguments": args, + }))).await.with_context(|| format!("calling MCP tool {}", name))?; + + if let Some(content) = result.get("content").and_then(|c| c.as_array()) { + let texts: Vec<&str> = content.iter() + .filter_map(|c| c.get("text").and_then(|t| t.as_str())) + .collect(); + Ok(texts.join("\n")) + } else if let Some(text) = result.as_str() { + Ok(text.to_string()) + } else { + Ok(serde_json::to_string_pretty(&result)?) + } +} + +pub(super) async fn tool_definitions_json() -> Vec { + let _ = ensure_init().await; + let reg = registry().lock().await; + reg.servers.iter() + .flat_map(|s| s.tools.iter()) + .map(|t| format!( + r#"{{"type":"function","function":{{"name":"{}","description":"{}","parameters":{}}}}}"#, + t.name, + t.description.replace('"', r#"\""#), + t.parameters_json, + )) + .collect() +} diff --git a/src/agent/tools/mod.rs b/src/agent/tools/mod.rs index bea0167..02b5fe8 100644 --- a/src/agent/tools/mod.rs +++ b/src/agent/tools/mod.rs @@ -6,6 +6,7 @@ // Core tools mod ast_grep; +pub mod mcp_client; mod bash; pub mod channels; mod edit; @@ -152,7 +153,22 @@ pub async fn dispatch_with_agent( match tool { Some(t) => (t.handler)(agent, args.clone()).await .unwrap_or_else(|e| format!("Error: {}", e)), - None => format!("Error: Unknown tool: {}", name), + None => { + let allowed = match &agent { + Some(a) => match &a.state.lock().await.mcp_tools { + super::McpToolAccess::All => true, + super::McpToolAccess::Some(list) => list.iter().any(|t| t == name), + super::McpToolAccess::None => false, + }, + None => true, + }; + if allowed { + if let Ok(result) = mcp_client::call_tool(name, args).await { + return result; + } + } + format!("Error: Unknown tool: {}", name) + } } } @@ -171,6 +187,12 @@ pub fn tools() -> Vec { all } +pub async fn all_tool_definitions() -> Vec { + let mut defs: Vec = tools().iter().map(|t| t.to_json()).collect(); + defs.extend(mcp_client::tool_definitions_json().await); + defs +} + /// Memory + journal tools only — for subconscious agents. pub fn memory_and_journal_tools() -> Vec { let mut all = memory::memory_tools().to_vec(); diff --git a/src/config.rs b/src/config.rs index 9a12f1f..1432547 100644 --- a/src/config.rs +++ b/src/config.rs @@ -107,6 +107,8 @@ pub struct Config { pub scoring_response_window: usize, pub api_reasoning: String, pub agent_types: Vec, + #[serde(default)] + pub mcp_servers: Vec, /// Surface agent timeout in seconds. #[serde(default)] pub surface_timeout_secs: Option, @@ -164,6 +166,7 @@ impl Default for Config { surface_timeout_secs: None, surface_conversation_bytes: None, surface_hooks: vec![], + mcp_servers: vec![], } } } @@ -346,6 +349,16 @@ pub struct AppConfig { pub models: HashMap, #[serde(default = "default_model_name")] pub default_model: String, + #[serde(default)] + pub mcp_servers: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpServerConfig { + pub name: String, + pub command: String, + #[serde(default)] + pub args: Vec, } #[derive(Debug, Clone, Default, Serialize, Deserialize)] @@ -436,6 +449,7 @@ impl Default for AppConfig { system_prompt_file: None, models: HashMap::new(), default_model: String::new(), + mcp_servers: Vec::new(), } } }