CLI: async runtime + proper RPC fallback plumbing
- main.rs: use #[tokio::main] so CLI has a runtime available - memory.rs: make run_with_local_store async (no more runtime creation) - mcp_server.rs: cache socket connection in OnceLock, use block_in_place for async fallback when socket unavailable Fixes "cannot start a runtime from within a runtime" panic when CLI falls back to local store. Co-Authored-By: Kent Overstreet <kent.overstreet@linux.dev>
This commit is contained in:
parent
7476e9d0db
commit
dc1049f62d
3 changed files with 96 additions and 55 deletions
|
|
@ -58,22 +58,14 @@ async fn cached_store() -> Result<Arc<crate::Mutex<Store>>> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run a tool with a temporarily-opened store (for rpc_local fallback).
|
/// Run a tool with a temporarily-opened store (for rpc_local fallback).
|
||||||
pub fn run_with_local_store(tool_name: &str, args: serde_json::Value) -> Result<String> {
|
pub async fn run_with_local_store(tool_name: &str, args: serde_json::Value) -> Result<String> {
|
||||||
let store = Store::load().map_err(|e| anyhow::anyhow!("{}", e))?;
|
let store = Store::cached().await.map_err(|e| anyhow::anyhow!("{}", e))?;
|
||||||
let arc = Arc::new(crate::Mutex::new(store));
|
|
||||||
|
|
||||||
LOCAL_STORE.with(|s| *s.borrow_mut() = Some(arc));
|
LOCAL_STORE.with(|s| *s.borrow_mut() = Some(store));
|
||||||
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
let result = dispatch(tool_name, &None, args).await;
|
||||||
let name = tool_name.to_string();
|
|
||||||
tokio::runtime::Builder::new_current_thread()
|
|
||||||
.enable_all()
|
|
||||||
.build()
|
|
||||||
.unwrap()
|
|
||||||
.block_on(dispatch(&name, &None, args))
|
|
||||||
}));
|
|
||||||
LOCAL_STORE.with(|s| *s.borrow_mut() = None);
|
LOCAL_STORE.with(|s| *s.borrow_mut() = None);
|
||||||
|
|
||||||
result.map_err(|_| anyhow::anyhow!("tool panicked"))?
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get provenance from agent, or from args._provenance, or "manual".
|
/// Get provenance from agent, or from args._provenance, or "manual".
|
||||||
|
|
|
||||||
|
|
@ -476,7 +476,8 @@ impl Run for AdminCmd {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
std::panic::set_backtrace_style(std::panic::BacktraceStyle::Short);
|
std::panic::set_backtrace_style(std::panic::BacktraceStyle::Short);
|
||||||
|
|
||||||
// Handle --help ourselves for expanded subcommand display
|
// Handle --help ourselves for expanded subcommand display
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ use anyhow::{Context, Result};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Mutex, OnceLock};
|
||||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
|
||||||
use tokio::net::{UnixListener, UnixStream};
|
use tokio::net::{UnixListener, UnixStream};
|
||||||
|
|
||||||
|
|
@ -21,54 +21,102 @@ pub fn socket_path() -> PathBuf {
|
||||||
.join(".consciousness/mcp.sock")
|
.join(".consciousness/mcp.sock")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cached socket connection
|
||||||
|
static SOCKET_CONN: OnceLock<Mutex<Option<SocketConn>>> = OnceLock::new();
|
||||||
|
|
||||||
|
struct SocketConn {
|
||||||
|
reader: std::io::BufReader<std::os::unix::net::UnixStream>,
|
||||||
|
writer: std::io::BufWriter<std::os::unix::net::UnixStream>,
|
||||||
|
next_id: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SocketConn {
|
||||||
|
fn connect() -> Result<Self> {
|
||||||
|
use std::os::unix::net::UnixStream;
|
||||||
|
use std::io::{BufRead, BufReader, BufWriter, Write};
|
||||||
|
|
||||||
|
let path = socket_path();
|
||||||
|
let stream = UnixStream::connect(&path)?;
|
||||||
|
let mut reader = BufReader::new(stream.try_clone()?);
|
||||||
|
let mut writer = BufWriter::new(stream);
|
||||||
|
|
||||||
|
// Initialize
|
||||||
|
let init = json!({"jsonrpc": "2.0", "id": 1, "method": "initialize",
|
||||||
|
"params": {"protocolVersion": "2024-11-05", "capabilities": {},
|
||||||
|
"clientInfo": {"name": "forward", "version": "0.1"}}});
|
||||||
|
writeln!(writer, "{}", init)?;
|
||||||
|
writer.flush()?;
|
||||||
|
let mut buf = String::new();
|
||||||
|
reader.read_line(&mut buf)?;
|
||||||
|
|
||||||
|
Ok(Self { reader, writer, next_id: 1 })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, tool_name: &str, args: &serde_json::Value) -> Result<String> {
|
||||||
|
use std::io::{BufRead, Write};
|
||||||
|
|
||||||
|
self.next_id += 1;
|
||||||
|
let call = json!({"jsonrpc": "2.0", "id": self.next_id, "method": "tools/call",
|
||||||
|
"params": {"name": tool_name, "arguments": args}});
|
||||||
|
writeln!(self.writer, "{}", call)?;
|
||||||
|
self.writer.flush()?;
|
||||||
|
|
||||||
|
let mut buf = String::new();
|
||||||
|
self.reader.read_line(&mut buf)?;
|
||||||
|
|
||||||
|
let resp: serde_json::Value = serde_json::from_str(&buf)?;
|
||||||
|
if let Some(err) = resp.get("error") {
|
||||||
|
anyhow::bail!("daemon error: {}", err);
|
||||||
|
}
|
||||||
|
let result = resp.get("result").cloned().unwrap_or(json!({}));
|
||||||
|
let text = result.get("content")
|
||||||
|
.and_then(|c| c.as_array())
|
||||||
|
.and_then(|arr| arr.first())
|
||||||
|
.and_then(|c| c.get("text"))
|
||||||
|
.and_then(|t| t.as_str())
|
||||||
|
.unwrap_or("");
|
||||||
|
Ok(text.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Forward a tool call to the daemon socket, or execute locally if daemon is down.
|
/// Forward a tool call to the daemon socket, or execute locally if daemon is down.
|
||||||
/// Used by external processes that don't have direct store access.
|
/// Used by external processes that don't have direct store access.
|
||||||
pub fn memory_rpc(tool_name: &str, args: serde_json::Value) -> Result<String> {
|
pub fn memory_rpc(tool_name: &str, args: serde_json::Value) -> Result<String> {
|
||||||
use std::os::unix::net::UnixStream;
|
let conn_lock = SOCKET_CONN.get_or_init(|| Mutex::new(None));
|
||||||
use std::io::{BufRead, BufReader, BufWriter, Write};
|
let mut guard = conn_lock.lock().unwrap();
|
||||||
|
|
||||||
let path = socket_path();
|
// Try cached connection first
|
||||||
let stream = match UnixStream::connect(&path) {
|
if let Some(conn) = guard.as_mut() {
|
||||||
Ok(s) => s,
|
match conn.call(tool_name, &args) {
|
||||||
Err(_) => return rpc_local(tool_name, &args),
|
Ok(result) => return Ok(result),
|
||||||
};
|
Err(_) => {
|
||||||
let mut reader = BufReader::new(stream.try_clone()?);
|
// Connection broken, clear cache and retry
|
||||||
let mut writer = BufWriter::new(stream);
|
*guard = None;
|
||||||
|
}
|
||||||
// Initialize
|
}
|
||||||
let init = json!({"jsonrpc": "2.0", "id": 1, "method": "initialize",
|
}
|
||||||
"params": {"protocolVersion": "2024-11-05", "capabilities": {},
|
|
||||||
"clientInfo": {"name": "forward", "version": "0.1"}}});
|
// Try to establish new connection
|
||||||
writeln!(writer, "{}", init)?;
|
match SocketConn::connect() {
|
||||||
writer.flush()?;
|
Ok(mut conn) => {
|
||||||
let mut buf = String::new();
|
let result = conn.call(tool_name, &args);
|
||||||
reader.read_line(&mut buf)?;
|
*guard = Some(conn);
|
||||||
|
result
|
||||||
// Call tool
|
}
|
||||||
let call = json!({"jsonrpc": "2.0", "id": 2, "method": "tools/call",
|
Err(_) => {
|
||||||
"params": {"name": tool_name, "arguments": args}});
|
// Socket unavailable - fall back to local store
|
||||||
writeln!(writer, "{}", call)?;
|
drop(guard); // Release lock before blocking
|
||||||
writer.flush()?;
|
tokio::task::block_in_place(|| {
|
||||||
buf.clear();
|
tokio::runtime::Handle::current()
|
||||||
reader.read_line(&mut buf)?;
|
.block_on(rpc_local(tool_name, &args))
|
||||||
|
})
|
||||||
let resp: serde_json::Value = serde_json::from_str(&buf)?;
|
}
|
||||||
if let Some(err) = resp.get("error") {
|
|
||||||
anyhow::bail!("daemon error: {}", err);
|
|
||||||
}
|
}
|
||||||
let result = resp.get("result").cloned().unwrap_or(json!({}));
|
|
||||||
let text = result.get("content")
|
|
||||||
.and_then(|c| c.as_array())
|
|
||||||
.and_then(|arr| arr.first())
|
|
||||||
.and_then(|c| c.get("text"))
|
|
||||||
.and_then(|t| t.as_str())
|
|
||||||
.unwrap_or("");
|
|
||||||
Ok(text.to_string())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Execute a tool locally when daemon isn't running.
|
/// Execute a tool locally when daemon isn't running.
|
||||||
fn rpc_local(tool_name: &str, args: &serde_json::Value) -> Result<String> {
|
async fn rpc_local(tool_name: &str, args: &serde_json::Value) -> Result<String> {
|
||||||
crate::agent::tools::memory::run_with_local_store(tool_name, args.clone())
|
crate::agent::tools::memory::run_with_local_store(tool_name, args.clone()).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue