From dc1049f62d1bf826242b3c6bd2d7fe1941c5cf86 Mon Sep 17 00:00:00 2001 From: Kent Overstreet Date: Mon, 13 Apr 2026 11:23:52 -0400 Subject: [PATCH] CLI: async runtime + proper RPC fallback plumbing - main.rs: use #[tokio::main] so CLI has a runtime available - memory.rs: make run_with_local_store async (no more runtime creation) - mcp_server.rs: cache socket connection in OnceLock, use block_in_place for async fallback when socket unavailable Fixes "cannot start a runtime from within a runtime" panic when CLI falls back to local store. Co-Authored-By: Kent Overstreet --- src/agent/tools/memory.rs | 18 ++---- src/main.rs | 3 +- src/mcp_server.rs | 130 ++++++++++++++++++++++++++------------ 3 files changed, 96 insertions(+), 55 deletions(-) diff --git a/src/agent/tools/memory.rs b/src/agent/tools/memory.rs index 358a6dd..65a5ad9 100644 --- a/src/agent/tools/memory.rs +++ b/src/agent/tools/memory.rs @@ -58,22 +58,14 @@ async fn cached_store() -> Result>> { } /// Run a tool with a temporarily-opened store (for rpc_local fallback). -pub fn run_with_local_store(tool_name: &str, args: serde_json::Value) -> Result { - let store = Store::load().map_err(|e| anyhow::anyhow!("{}", e))?; - let arc = Arc::new(crate::Mutex::new(store)); +pub async fn run_with_local_store(tool_name: &str, args: serde_json::Value) -> Result { + let store = Store::cached().await.map_err(|e| anyhow::anyhow!("{}", e))?; - LOCAL_STORE.with(|s| *s.borrow_mut() = Some(arc)); - let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - let name = tool_name.to_string(); - tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap() - .block_on(dispatch(&name, &None, args)) - })); + LOCAL_STORE.with(|s| *s.borrow_mut() = Some(store)); + let result = dispatch(tool_name, &None, args).await; LOCAL_STORE.with(|s| *s.borrow_mut() = None); - result.map_err(|_| anyhow::anyhow!("tool panicked"))? + result } /// Get provenance from agent, or from args._provenance, or "manual". diff --git a/src/main.rs b/src/main.rs index 1a39fdc..990a62b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -476,7 +476,8 @@ impl Run for AdminCmd { } } -fn main() { +#[tokio::main] +async fn main() { std::panic::set_backtrace_style(std::panic::BacktraceStyle::Short); // Handle --help ourselves for expanded subcommand display diff --git a/src/mcp_server.rs b/src/mcp_server.rs index 816fec2..935e211 100644 --- a/src/mcp_server.rs +++ b/src/mcp_server.rs @@ -9,7 +9,7 @@ use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; use serde_json::json; use std::path::PathBuf; -use std::sync::Arc; +use std::sync::{Arc, Mutex, OnceLock}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}; use tokio::net::{UnixListener, UnixStream}; @@ -21,54 +21,102 @@ pub fn socket_path() -> PathBuf { .join(".consciousness/mcp.sock") } +// Cached socket connection +static SOCKET_CONN: OnceLock>> = OnceLock::new(); + +struct SocketConn { + reader: std::io::BufReader, + writer: std::io::BufWriter, + next_id: u64, +} + +impl SocketConn { + fn connect() -> Result { + use std::os::unix::net::UnixStream; + use std::io::{BufRead, BufReader, BufWriter, Write}; + + let path = socket_path(); + let stream = UnixStream::connect(&path)?; + let mut reader = BufReader::new(stream.try_clone()?); + let mut writer = BufWriter::new(stream); + + // Initialize + let init = json!({"jsonrpc": "2.0", "id": 1, "method": "initialize", + "params": {"protocolVersion": "2024-11-05", "capabilities": {}, + "clientInfo": {"name": "forward", "version": "0.1"}}}); + writeln!(writer, "{}", init)?; + writer.flush()?; + let mut buf = String::new(); + reader.read_line(&mut buf)?; + + Ok(Self { reader, writer, next_id: 1 }) + } + + fn call(&mut self, tool_name: &str, args: &serde_json::Value) -> Result { + use std::io::{BufRead, Write}; + + self.next_id += 1; + let call = json!({"jsonrpc": "2.0", "id": self.next_id, "method": "tools/call", + "params": {"name": tool_name, "arguments": args}}); + writeln!(self.writer, "{}", call)?; + self.writer.flush()?; + + let mut buf = String::new(); + self.reader.read_line(&mut buf)?; + + let resp: serde_json::Value = serde_json::from_str(&buf)?; + if let Some(err) = resp.get("error") { + anyhow::bail!("daemon error: {}", err); + } + let result = resp.get("result").cloned().unwrap_or(json!({})); + let text = result.get("content") + .and_then(|c| c.as_array()) + .and_then(|arr| arr.first()) + .and_then(|c| c.get("text")) + .and_then(|t| t.as_str()) + .unwrap_or(""); + Ok(text.to_string()) + } +} + /// Forward a tool call to the daemon socket, or execute locally if daemon is down. /// Used by external processes that don't have direct store access. pub fn memory_rpc(tool_name: &str, args: serde_json::Value) -> Result { - use std::os::unix::net::UnixStream; - use std::io::{BufRead, BufReader, BufWriter, Write}; + let conn_lock = SOCKET_CONN.get_or_init(|| Mutex::new(None)); + let mut guard = conn_lock.lock().unwrap(); - let path = socket_path(); - let stream = match UnixStream::connect(&path) { - Ok(s) => s, - Err(_) => return rpc_local(tool_name, &args), - }; - let mut reader = BufReader::new(stream.try_clone()?); - let mut writer = BufWriter::new(stream); - - // Initialize - let init = json!({"jsonrpc": "2.0", "id": 1, "method": "initialize", - "params": {"protocolVersion": "2024-11-05", "capabilities": {}, - "clientInfo": {"name": "forward", "version": "0.1"}}}); - writeln!(writer, "{}", init)?; - writer.flush()?; - let mut buf = String::new(); - reader.read_line(&mut buf)?; - - // Call tool - let call = json!({"jsonrpc": "2.0", "id": 2, "method": "tools/call", - "params": {"name": tool_name, "arguments": args}}); - writeln!(writer, "{}", call)?; - writer.flush()?; - buf.clear(); - reader.read_line(&mut buf)?; - - let resp: serde_json::Value = serde_json::from_str(&buf)?; - if let Some(err) = resp.get("error") { - anyhow::bail!("daemon error: {}", err); + // Try cached connection first + if let Some(conn) = guard.as_mut() { + match conn.call(tool_name, &args) { + Ok(result) => return Ok(result), + Err(_) => { + // Connection broken, clear cache and retry + *guard = None; + } + } + } + + // Try to establish new connection + match SocketConn::connect() { + Ok(mut conn) => { + let result = conn.call(tool_name, &args); + *guard = Some(conn); + result + } + Err(_) => { + // Socket unavailable - fall back to local store + drop(guard); // Release lock before blocking + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current() + .block_on(rpc_local(tool_name, &args)) + }) + } } - let result = resp.get("result").cloned().unwrap_or(json!({})); - let text = result.get("content") - .and_then(|c| c.as_array()) - .and_then(|arr| arr.first()) - .and_then(|c| c.get("text")) - .and_then(|t| t.as_str()) - .unwrap_or(""); - Ok(text.to_string()) } /// Execute a tool locally when daemon isn't running. -fn rpc_local(tool_name: &str, args: &serde_json::Value) -> Result { - crate::agent::tools::memory::run_with_local_store(tool_name, args.clone()) +async fn rpc_local(tool_name: &str, args: &serde_json::Value) -> Result { + crate::agent::tools::memory::run_with_local_store(tool_name, args.clone()).await } #[derive(Debug, Deserialize)]