193 lines
6.7 KiB
Rust
193 lines
6.7 KiB
Rust
|
|
// tools/mcp_client.rs — MCP client for external tool servers
//
// Spawns external MCP servers, discovers their tools, and dispatches calls.
// Transport is JSON-RPC 2.0 over stdio (newline-delimited). Servers live in a
// global registry that is lazily initialized from config.
|
||
|
|
|
||
|
|
use anyhow::{Context, Result};
|
||
|
|
use serde::Deserialize;
|
||
|
|
use serde_json::json;
|
||
|
|
use std::sync::OnceLock;
|
||
|
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
|
||
|
|
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
||
|
|
use tokio::sync::Mutex as TokioMutex;
|
||
|
|
|
||
|
|
/// A tool advertised by an MCP server in its `tools/list` response.
#[derive(Clone, Debug)]
pub struct McpTool {
    /// Tool name as reported by the server; `tools/call` is routed by it.
    pub name: String,
    /// Human-readable description (empty when the server supplied none).
    pub description: String,
    /// The tool's `inputSchema`, serialized as JSON text
    /// (`{"type":"object"}` when the server omitted it).
    pub parameters_json: String,
}
|
||
|
|
|
||
|
|
struct McpServer {
|
||
|
|
#[allow(dead_code)]
|
||
|
|
name: String,
|
||
|
|
stdin: BufWriter<ChildStdin>,
|
||
|
|
stdout: BufReader<ChildStdout>,
|
||
|
|
_child: Child,
|
||
|
|
next_id: u64,
|
||
|
|
tools: Vec<McpTool>,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Deserialize)]
|
||
|
|
struct JsonRpcResponse {
|
||
|
|
id: Option<u64>,
|
||
|
|
result: Option<serde_json::Value>,
|
||
|
|
error: Option<JsonRpcError>,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Deserialize)]
|
||
|
|
struct JsonRpcError {
|
||
|
|
code: i64,
|
||
|
|
message: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl McpServer {
|
||
|
|
async fn request(&mut self, method: &str, params: Option<serde_json::Value>) -> Result<serde_json::Value> {
|
||
|
|
self.next_id += 1;
|
||
|
|
let id = self.next_id;
|
||
|
|
let req = json!({ "jsonrpc": "2.0", "id": id, "method": method, "params": params });
|
||
|
|
let mut line = serde_json::to_string(&req)?;
|
||
|
|
line.push('\n');
|
||
|
|
self.stdin.write_all(line.as_bytes()).await?;
|
||
|
|
self.stdin.flush().await?;
|
||
|
|
|
||
|
|
let mut buf = String::new();
|
||
|
|
loop {
|
||
|
|
buf.clear();
|
||
|
|
let n = self.stdout.read_line(&mut buf).await?;
|
||
|
|
if n == 0 { anyhow::bail!("MCP server closed connection"); }
|
||
|
|
let trimmed = buf.trim();
|
||
|
|
if trimmed.is_empty() { continue; }
|
||
|
|
if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(trimmed) {
|
||
|
|
if resp.id == Some(id) {
|
||
|
|
if let Some(err) = resp.error {
|
||
|
|
anyhow::bail!("MCP error {}: {}", err.code, err.message);
|
||
|
|
}
|
||
|
|
return Ok(resp.result.unwrap_or(serde_json::Value::Null));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn notify(&mut self, method: &str) -> Result<()> {
|
||
|
|
let msg = json!({ "jsonrpc": "2.0", "method": method });
|
||
|
|
let mut line = serde_json::to_string(&msg)?;
|
||
|
|
line.push('\n');
|
||
|
|
self.stdin.write_all(line.as_bytes()).await?;
|
||
|
|
self.stdin.flush().await?;
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn spawn(name: &str, command: &str, args: &[&str]) -> Result<McpServer> {
|
||
|
|
let mut child = Command::new(command)
|
||
|
|
.args(args)
|
||
|
|
.stdin(std::process::Stdio::piped())
|
||
|
|
.stdout(std::process::Stdio::piped())
|
||
|
|
.stderr(std::process::Stdio::null())
|
||
|
|
.spawn()
|
||
|
|
.with_context(|| format!("spawning MCP server: {} {}", command, args.join(" ")))?;
|
||
|
|
|
||
|
|
let mut server = McpServer {
|
||
|
|
name: name.to_string(),
|
||
|
|
stdin: BufWriter::new(child.stdin.take().unwrap()),
|
||
|
|
stdout: BufReader::new(child.stdout.take().unwrap()),
|
||
|
|
_child: child,
|
||
|
|
next_id: 0,
|
||
|
|
tools: Vec::new(),
|
||
|
|
};
|
||
|
|
|
||
|
|
server.request("initialize", Some(json!({
|
||
|
|
"protocolVersion": "2024-11-05",
|
||
|
|
"capabilities": {},
|
||
|
|
"clientInfo": {"name": "consciousness", "version": "0.1"}
|
||
|
|
}))).await.with_context(|| format!("initializing MCP server {}", name))?;
|
||
|
|
|
||
|
|
server.notify("notifications/initialized").await?;
|
||
|
|
|
||
|
|
let tools_result = server.request("tools/list", None).await
|
||
|
|
.with_context(|| format!("listing tools from MCP server {}", name))?;
|
||
|
|
|
||
|
|
if let Some(tool_list) = tools_result.get("tools").and_then(|t| t.as_array()) {
|
||
|
|
for tool in tool_list {
|
||
|
|
server.tools.push(McpTool {
|
||
|
|
name: tool["name"].as_str().unwrap_or("").to_string(),
|
||
|
|
description: tool["description"].as_str().unwrap_or("").to_string(),
|
||
|
|
parameters_json: tool.get("inputSchema")
|
||
|
|
.map(|s| serde_json::to_string(s).unwrap_or_default())
|
||
|
|
.unwrap_or_else(|| r#"{"type":"object"}"#.to_string()),
|
||
|
|
});
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
dbglog!("[mcp] {} connected: {} tools", name, server.tools.len());
|
||
|
|
Ok(server)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
struct Registry {
|
||
|
|
servers: Vec<McpServer>,
|
||
|
|
}
|
||
|
|
|
||
|
|
static REGISTRY: OnceLock<TokioMutex<Registry>> = OnceLock::new();
|
||
|
|
|
||
|
|
fn registry() -> &'static TokioMutex<Registry> {
|
||
|
|
REGISTRY.get_or_init(|| {
|
||
|
|
let configs = &crate::config::get().mcp_servers;
|
||
|
|
// Can't do async init in OnceLock, so servers are spawned lazily on first access
|
||
|
|
let _ = configs; // configs read but servers spawned in ensure_init()
|
||
|
|
TokioMutex::new(Registry { servers: Vec::new() })
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn ensure_init() -> Result<()> {
|
||
|
|
let mut reg = registry().lock().await;
|
||
|
|
if !reg.servers.is_empty() { return Ok(()); }
|
||
|
|
let configs = crate::config::get().mcp_servers.clone();
|
||
|
|
for cfg in &configs {
|
||
|
|
let args: Vec<&str> = cfg.args.iter().map(|s| s.as_str()).collect();
|
||
|
|
match McpServer::spawn(&cfg.name, &cfg.command, &args).await {
|
||
|
|
Ok(server) => reg.servers.push(server),
|
||
|
|
Err(e) => eprintln!("warning: MCP server {} failed: {:#}", cfg.name, e),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
pub(super) async fn call_tool(name: &str, args: &serde_json::Value) -> Result<String> {
|
||
|
|
ensure_init().await?;
|
||
|
|
let mut reg = registry().lock().await;
|
||
|
|
let server = reg.servers.iter_mut()
|
||
|
|
.find(|s| s.tools.iter().any(|t| t.name == name))
|
||
|
|
.ok_or_else(|| anyhow::anyhow!("no MCP server has tool {}", name))?;
|
||
|
|
|
||
|
|
let result = server.request("tools/call", Some(json!({
|
||
|
|
"name": name, "arguments": args,
|
||
|
|
}))).await.with_context(|| format!("calling MCP tool {}", name))?;
|
||
|
|
|
||
|
|
if let Some(content) = result.get("content").and_then(|c| c.as_array()) {
|
||
|
|
let texts: Vec<&str> = content.iter()
|
||
|
|
.filter_map(|c| c.get("text").and_then(|t| t.as_str()))
|
||
|
|
.collect();
|
||
|
|
Ok(texts.join("\n"))
|
||
|
|
} else if let Some(text) = result.as_str() {
|
||
|
|
Ok(text.to_string())
|
||
|
|
} else {
|
||
|
|
Ok(serde_json::to_string_pretty(&result)?)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub(super) async fn tool_definitions_json() -> Vec<String> {
|
||
|
|
let _ = ensure_init().await;
|
||
|
|
let reg = registry().lock().await;
|
||
|
|
reg.servers.iter()
|
||
|
|
.flat_map(|s| s.tools.iter())
|
||
|
|
.map(|t| format!(
|
||
|
|
r#"{{"type":"function","function":{{"name":"{}","description":"{}","parameters":{}}}}}"#,
|
||
|
|
t.name,
|
||
|
|
t.description.replace('"', r#"\""#),
|
||
|
|
t.parameters_json,
|
||
|
|
))
|
||
|
|
.collect()
|
||
|
|
}
|