consciousness/src/agent/tools/mcp_client.rs
ProofOfConcept 15f3be27ce Show MCP server failures in the UI instead of debug log
MCP server spawn failures were going to dbglog where the user
wouldn't see them. Route through the agent's notify so they appear
on the status bar.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
2026-04-09 22:46:48 -04:00

202 lines
7 KiB
Rust

// tools/mcp_client.rs — MCP client for external tool servers
//
// Spawns external MCP servers, discovers their tools, dispatches calls.
// JSON-RPC 2.0 over stdio (newline-delimited). Global registry, lazy
// init from config.
use anyhow::{Context, Result};
use serde::Deserialize;
use serde_json::json;
use std::sync::OnceLock;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::Mutex as TokioMutex;
/// One tool advertised by an MCP server in its `tools/list` reply.
#[derive(Debug, Clone)]
pub struct McpTool {
    /// Server-reported tool name; `call_tool` routes `tools/call` by matching it.
    pub name: String,
    /// Free-form description text ("" when the server sent none).
    pub description: String,
    /// The tool's `inputSchema`, serialized to a JSON string
    /// (`{"type":"object"}` when the server sent no schema).
    pub parameters_json: String,
}
/// A live connection to one spawned MCP server process, speaking
/// newline-delimited JSON-RPC 2.0 over the child's stdin/stdout.
struct McpServer {
    /// Configured server name. Kept for identification; nothing in this
    /// module reads it yet, hence the lint allowance.
    #[allow(dead_code)]
    name: String,
    /// Outgoing side: requests and notifications are written here, one per line.
    stdin: BufWriter<ChildStdin>,
    /// Incoming side: responses (and any server-originated messages) are read here.
    stdout: BufReader<ChildStdout>,
    /// Owned process handle, held for as long as the connection exists.
    _child: Child,
    /// Last JSON-RPC request id used; incremented before each request.
    next_id: u64,
    /// Tools discovered via `tools/list` during startup.
    tools: Vec<McpTool>,
}
/// Subset of a JSON-RPC 2.0 response that this client inspects.
/// Unknown fields (e.g. `jsonrpc`) are ignored by serde's default behaviour.
#[derive(Debug, Deserialize)]
struct JsonRpcResponse {
    /// Echoed request id; `None` when absent or `null`.
    id: Option<u64>,
    /// Success payload, if any.
    result: Option<serde_json::Value>,
    /// Failure payload, if any.
    error: Option<JsonRpcError>,
}

/// The `error` member of a JSON-RPC 2.0 response (the optional `data` is not read).
#[derive(Debug, Deserialize)]
struct JsonRpcError {
    /// Numeric JSON-RPC error code.
    code: i64,
    /// Server-provided error text.
    message: String,
}
impl McpServer {
async fn request(&mut self, method: &str, params: Option<serde_json::Value>) -> Result<serde_json::Value> {
self.next_id += 1;
let id = self.next_id;
let req = json!({ "jsonrpc": "2.0", "id": id, "method": method, "params": params });
let mut line = serde_json::to_string(&req)?;
line.push('\n');
self.stdin.write_all(line.as_bytes()).await?;
self.stdin.flush().await?;
let mut buf = String::new();
loop {
buf.clear();
let n = self.stdout.read_line(&mut buf).await?;
if n == 0 { anyhow::bail!("MCP server closed connection"); }
let trimmed = buf.trim();
if trimmed.is_empty() { continue; }
if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(trimmed) {
if resp.id == Some(id) {
if let Some(err) = resp.error {
anyhow::bail!("MCP error {}: {}", err.code, err.message);
}
return Ok(resp.result.unwrap_or(serde_json::Value::Null));
}
}
}
}
async fn notify(&mut self, method: &str) -> Result<()> {
let msg = json!({ "jsonrpc": "2.0", "method": method });
let mut line = serde_json::to_string(&msg)?;
line.push('\n');
self.stdin.write_all(line.as_bytes()).await?;
self.stdin.flush().await?;
Ok(())
}
async fn spawn(name: &str, command: &str, args: &[&str]) -> Result<McpServer> {
let mut child = Command::new(command)
.args(args)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::null())
.spawn()
.with_context(|| format!("spawning MCP server: {} {}", command, args.join(" ")))?;
let mut server = McpServer {
name: name.to_string(),
stdin: BufWriter::new(child.stdin.take().unwrap()),
stdout: BufReader::new(child.stdout.take().unwrap()),
_child: child,
next_id: 0,
tools: Vec::new(),
};
server.request("initialize", Some(json!({
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "consciousness", "version": "0.1"}
}))).await.with_context(|| format!("initializing MCP server {}", name))?;
server.notify("notifications/initialized").await?;
let tools_result = server.request("tools/list", None).await
.with_context(|| format!("listing tools from MCP server {}", name))?;
if let Some(tool_list) = tools_result.get("tools").and_then(|t| t.as_array()) {
for tool in tool_list {
server.tools.push(McpTool {
name: tool["name"].as_str().unwrap_or("").to_string(),
description: tool["description"].as_str().unwrap_or("").to_string(),
parameters_json: tool.get("inputSchema")
.map(|s| serde_json::to_string(s).unwrap_or_default())
.unwrap_or_else(|| r#"{"type":"object"}"#.to_string()),
});
}
}
dbglog!("[mcp] {} connected: {} tools", name, server.tools.len());
Ok(server)
}
}
/// Process-wide collection of connected MCP servers.
struct Registry {
    /// Servers that spawned and completed the handshake.
    servers: Vec<McpServer>,
    /// Set once the configured servers have been spawned, whether or not they
    /// succeeded. Failed servers are not respawned on every call.
    initialized: bool,
    /// Spawn failures not yet shown to the user. They are delivered by the
    /// next `ensure_init` call that has an agent to notify.
    pending_errors: Vec<String>,
}

/// Lazily created registry. Server processes are started by `ensure_init`,
/// not here, because `OnceLock` initialization cannot be async.
static REGISTRY: OnceLock<TokioMutex<Registry>> = OnceLock::new();

/// Return the global registry, creating it empty on first access.
fn registry() -> &'static TokioMutex<Registry> {
    REGISTRY.get_or_init(|| {
        TokioMutex::new(Registry {
            servers: Vec::new(),
            initialized: false,
            pending_errors: Vec::new(),
        })
    })
}

/// Spawn every configured MCP server on first use, then report any queued
/// spawn failures through `agent`'s notify (the status bar) when an agent is
/// given and its state lock is free.
///
/// Failures are always written to the debug log as well. This function
/// itself never returns `Err`; the `Result` is kept for callers.
async fn ensure_init(agent: Option<&std::sync::Arc<super::super::Agent>>) -> Result<()> {
    let mut reg = registry().lock().await;

    if !reg.initialized {
        // Spawn into locals and commit at the end. If this future is cancelled
        // part-way, nothing is recorded and the next call starts over.
        let configs = crate::config::get().mcp_servers.clone();
        let mut servers = Vec::new();
        let mut errors = Vec::new();
        for cfg in &configs {
            let args: Vec<&str> = cfg.args.iter().map(|s| s.as_str()).collect();
            match McpServer::spawn(&cfg.name, &cfg.command, &args).await {
                Ok(server) => servers.push(server),
                Err(e) => {
                    let msg = format!("MCP server {} failed: {:#}", cfg.name, e);
                    dbglog!("{}", msg);
                    errors.push(msg);
                }
            }
        }
        reg.servers = servers;
        reg.pending_errors = errors;
        reg.initialized = true;
    }

    // Failures found during a call without an agent (e.g. when building tool
    // definitions) are held until a caller can surface them in the UI. If the
    // agent's state is busy, they stay queued for a later call.
    if let Some(a) = agent {
        if !reg.pending_errors.is_empty() {
            if let Ok(mut st) = a.state.try_lock() {
                for msg in reg.pending_errors.drain(..) {
                    st.notify(msg);
                }
            }
        }
    }
    Ok(())
}
/// Invoke the MCP tool `name` with `args` on whichever server advertised it.
///
/// The reply is flattened to a single string:
/// - if the result has a `content` array, the `text` of each item, joined
///   with newlines (items without `text` are skipped);
/// - otherwise, if the result is itself a string, that string;
/// - otherwise, the result pretty-printed as JSON.
///
/// The registry lock is held for the whole call, so MCP tool calls are serialized.
///
/// # Errors
/// Fails if no server offers `name`, or if the request to the server fails.
pub(super) async fn call_tool(name: &str, args: &serde_json::Value,
    agent: Option<&std::sync::Arc<super::super::Agent>>,
) -> Result<String> {
    ensure_init(agent).await?;
    let mut guard = registry().lock().await;

    let owner = match guard
        .servers
        .iter_mut()
        .find(|srv| srv.tools.iter().any(|tool| tool.name == name))
    {
        Some(srv) => srv,
        None => anyhow::bail!("no MCP server has tool {}", name),
    };

    let call_params = json!({ "name": name, "arguments": args });
    let reply = owner
        .request("tools/call", Some(call_params))
        .await
        .with_context(|| format!("calling MCP tool {}", name))?;

    match reply.get("content").and_then(serde_json::Value::as_array) {
        Some(items) => {
            let pieces: Vec<&str> = items
                .iter()
                .filter_map(|item| item.get("text")?.as_str())
                .collect();
            Ok(pieces.join("\n"))
        }
        None => match reply.as_str() {
            Some(plain) => Ok(plain.to_owned()),
            None => Ok(serde_json::to_string_pretty(&reply)?),
        },
    }
}
pub(super) async fn tool_definitions_json() -> Vec<String> {
let _ = ensure_init(None).await;
let reg = registry().lock().await;
reg.servers.iter()
.flat_map(|s| s.tools.iter())
.map(|t| format!(
r#"{{"type":"function","function":{{"name":"{}","description":"{}","parameters":{}}}}}"#,
t.name,
t.description.replace('"', r#"\""#),
t.parameters_json,
))
.collect()
}