consciousness/channels/socat/src/main.rs

321 lines
12 KiB
Rust
Raw Normal View History

// channel-socat — Generic stream channel daemon
//
// Listens on a unix socket for incoming connections. Each connection
// becomes a bidirectional text channel. Also supports outbound
// connections via the open RPC.
//
// Socket: ~/.consciousness/channels/socat.sock (capnp RPC)
// Listen: ~/.consciousness/channels/socat.stream.sock (data)
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::rc::Rc;
use capnp::capability::Promise;
use capnp_rpc::{pry, rpc_twoparty_capnp, twoparty, RpcSystem};
use futures::AsyncReadExt;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UnixListener, UnixStream};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tracing::{error, info, warn};
use poc_memory::channel_capnp::{channel_client, channel_server};
use poc_memory::thalamus::channel_log::ChannelLog;
// ── State ──────────────────────────────────────────────────────
/// Per-channel state: the channel's message log plus the outbound queue of
/// the connection currently attached to it, if any.
struct ChannelState {
/// Lines received from the peer (`push`) and echoes of messages we sent
/// through the `send` RPC (`push_own`).
log: ChannelLog,
/// Sender feeding the attached connection's writer task. `None` when no
/// connection is attached: never connected, disconnected, or closed via RPC.
writer: Option<tokio::sync::mpsc::UnboundedSender<String>>,
}
/// Daemon-wide state shared by the stream listener, the per-connection
/// stream handlers and the RPC server (through `SharedState`).
struct State {
/// All known channels keyed by name, e.g. `socat.conn`, `socat.conn.1`.
channels: BTreeMap<String, ChannelState>,
/// Callbacks registered via the `subscribe` RPC. `push_message` notifies
/// each of them for every line pushed to any channel.
subscribers: Vec<channel_client::Client>,
/// Counter used by `next_channel_key` to give incoming connections
/// distinct channel keys.
next_id: u32,
}
/// Shared handle to the daemon state. `Rc<RefCell<..>>` (not `Arc<Mutex<..>>`)
/// is enough because all tasks are spawned on one `LocalSet` of a
/// current-thread runtime.
type SharedState = Rc<RefCell<State>>;
impl State {
fn new() -> Self {
Self {
channels: BTreeMap::new(),
subscribers: Vec::new(),
next_id: 0,
}
}
fn next_channel_key(&mut self, label: &str) -> String {
let key = if self.next_id == 0 {
format!("socat.{}", label)
} else {
format!("socat.{}.{}", label, self.next_id)
};
self.next_id += 1;
key
}
fn push_message(&mut self, channel: &str, line: String, urgency: u8) {
let ch = self.channels
.entry(channel.to_string())
.or_insert_with(|| ChannelState { log: ChannelLog::new(), writer: None });
ch.log.push(line.clone());
let preview: String = line.chars().take(80).collect();
for sub in &self.subscribers {
let mut req = sub.notify_request();
let mut list = req.get().init_notifications(1);
let mut n = list.reborrow().get(0);
n.set_channel(channel);
n.set_urgency(urgency);
n.set_preview(&preview);
n.set_count(1);
tokio::task::spawn_local(async move {
let _ = req.send().promise.await;
});
}
}
}
// ── Stream handler ─────────────────────────────────────────────
async fn handle_stream<R, W>(state: SharedState, channel_key: String, reader: R, mut writer: W)
where
R: tokio::io::AsyncRead + Unpin + 'static,
W: tokio::io::AsyncWrite + Unpin + 'static,
{
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();
{
let mut s = state.borrow_mut();
let ch = s.channels
.entry(channel_key.clone())
.or_insert_with(|| ChannelState { log: ChannelLog::new(), writer: None });
ch.writer = Some(tx);
}
info!("channel {} connected", channel_key);
// Writer task
let wk = channel_key.clone();
let write_handle = tokio::task::spawn_local(async move {
while let Some(msg) = rx.recv().await {
if writer.write_all(msg.as_bytes()).await.is_err() { break; }
if !msg.ends_with('\n') {
if writer.write_all(b"\n").await.is_err() { break; }
}
let _ = writer.flush().await;
}
warn!("writer ended for {}", wk);
});
// Read lines
let mut lines = tokio::io::BufReader::new(reader).lines();
while let Ok(Some(line)) = lines.next_line().await {
if line.trim().is_empty() { continue; }
state.borrow_mut().push_message(&channel_key, line, 2);
}
info!("channel {} disconnected", channel_key);
{
let mut s = state.borrow_mut();
if let Some(ch) = s.channels.get_mut(&channel_key) {
ch.writer = None;
}
}
write_handle.abort();
}
// ── Outbound connections ───────────────────────────────────────
/// Opens an outbound connection and serves it as channel `socat.<label>`.
///
/// `addr` is `unix:<path>`, `tcp:<host:port>`, or a bare `<host:port>` (which
/// is treated as TCP). If the channel already has a writer attached, nothing
/// is done and `Ok(())` is returned. On success the stream is handed to
/// `handle_stream` on a local task; connect failures are returned as text.
async fn connect_outbound(state: SharedState, label: String, addr: String) -> Result<(), String> {
    let channel_key = format!("socat.{}", label);

    // The `Ref` is a temporary dropped at the end of this statement, so no
    // borrow is held across the awaits below.
    let already_connected = state
        .borrow()
        .channels
        .get(&channel_key)
        .map_or(false, |ch| ch.writer.is_some());
    if already_connected {
        return Ok(());
    }

    if let Some(path) = addr.strip_prefix("unix:") {
        let (r, w) = UnixStream::connect(path)
            .await
            .map_err(|e| format!("unix connect failed: {e}"))?
            .into_split();
        tokio::task::spawn_local(handle_stream(state, channel_key, r, w));
        return Ok(());
    }

    // Everything else is TCP; only the error wording depends on the prefix.
    let (target, failure) = match addr.strip_prefix("tcp:") {
        Some(rest) => (rest, "tcp connect failed"),
        None => (addr.as_str(), "connect failed"),
    };
    let (r, w) = TcpStream::connect(target)
        .await
        .map_err(|e| format!("{failure}: {e}"))?
        .into_split();
    tokio::task::spawn_local(handle_stream(state, channel_key, r, w));
    Ok(())
}
// ── ChannelServer ──────────────────────────────────────────────
/// Cap'n Proto `ChannelServer` implementation. `main` creates one instance per
/// accepted RPC connection; all instances share the same daemon state.
struct ChannelServerImpl { state: SharedState }
impl channel_server::Server for ChannelServerImpl {
fn recv(
&mut self, params: channel_server::RecvParams, mut results: channel_server::RecvResults,
) -> Promise<(), capnp::Error> {
let params = pry!(params.get());
let channel = pry!(pry!(params.get_channel()).to_str()).to_string();
let all_new = params.get_all_new();
let min_count = params.get_min_count() as usize;
let mut s = self.state.borrow_mut();
let text = s.channels.get_mut(&channel)
.map(|ch| if all_new { ch.log.recv_new(min_count) } else { ch.log.recv_history(min_count) })
.unwrap_or_default();
results.get().set_text(&text);
Promise::ok(())
}
fn send(
&mut self, params: channel_server::SendParams, _results: channel_server::SendResults,
) -> Promise<(), capnp::Error> {
let params = pry!(params.get());
let channel = pry!(pry!(params.get_channel()).to_str()).to_string();
let message = pry!(pry!(params.get_message()).to_str()).to_string();
let mut s = self.state.borrow_mut();
if let Some(ch) = s.channels.get_mut(&channel) {
if let Some(ref tx) = ch.writer {
let _ = tx.send(message.clone());
}
ch.log.push_own(format!("> {}", message));
}
Promise::ok(())
}
fn list(
&mut self, _params: channel_server::ListParams, mut results: channel_server::ListResults,
) -> Promise<(), capnp::Error> {
let s = self.state.borrow();
let channels: Vec<_> = s.channels.iter()
.map(|(name, ch)| (name.clone(), ch.writer.is_some(), ch.log.unread()))
.collect();
let mut list = results.get().init_channels(channels.len() as u32);
for (i, (name, connected, unread)) in channels.iter().enumerate() {
let mut entry = list.reborrow().get(i as u32);
entry.set_name(&name);
entry.set_connected(*connected);
entry.set_unread(*unread as u32);
}
Promise::ok(())
}
fn subscribe(
&mut self, params: channel_server::SubscribeParams, _results: channel_server::SubscribeResults,
) -> Promise<(), capnp::Error> {
let callback = pry!(pry!(params.get()).get_callback());
self.state.borrow_mut().subscribers.push(callback);
Promise::ok(())
}
fn open(
&mut self, params: channel_server::OpenParams, _results: channel_server::OpenResults,
) -> Promise<(), capnp::Error> {
let params = pry!(params.get());
let label = pry!(pry!(params.get_label()).to_str()).to_string();
let state = self.state.clone();
Promise::from_future(async move {
connect_outbound(state, label.clone(), label).await
.map_err(|e| capnp::Error::failed(e))
})
}
fn close(
&mut self, params: channel_server::CloseParams, _results: channel_server::CloseResults,
) -> Promise<(), capnp::Error> {
let params = pry!(params.get());
let channel = pry!(pry!(params.get_channel()).to_str()).to_string();
let mut s = self.state.borrow_mut();
if let Some(ch) = s.channels.get_mut(&channel) {
info!("closing {}", channel);
ch.writer = None;
}
Promise::ok(())
}
}
// ── Main ───────────────────────────────────────────────────────
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
let dir = dirs::home_dir()
.unwrap_or_default()
.join(".consciousness/channels");
std::fs::create_dir_all(&dir)?;
let rpc_sock = dir.join("socat.sock");
let stream_sock = dir.join("socat.stream.sock");
let _ = std::fs::remove_file(&rpc_sock);
let _ = std::fs::remove_file(&stream_sock);
info!("socat daemon starting");
info!(" rpc: {}", rpc_sock.display());
info!(" stream: {}", stream_sock.display());
let state = Rc::new(RefCell::new(State::new()));
tokio::task::LocalSet::new()
.run_until(async move {
// Listen for data connections — each becomes a channel
let stream_listener = UnixListener::bind(&stream_sock)?;
let stream_state = state.clone();
tokio::task::spawn_local(async move {
loop {
match stream_listener.accept().await {
Ok((stream, _)) => {
let key = stream_state.borrow_mut().next_channel_key("conn");
info!("incoming connection → {}", key);
let (r, w) = stream.into_split();
let s = stream_state.clone();
tokio::task::spawn_local(handle_stream(s, key, r, w));
}
Err(e) => error!("stream accept error: {}", e),
}
}
});
// Listen for capnp RPC connections
let rpc_listener = UnixListener::bind(&rpc_sock)?;
loop {
let (stream, _) = rpc_listener.accept().await?;
let (reader, writer) = stream.compat().split();
let network = twoparty::VatNetwork::new(
futures::io::BufReader::new(reader),
futures::io::BufWriter::new(writer),
rpc_twoparty_capnp::Side::Server,
Default::default(),
);
let server = ChannelServerImpl { state: state.clone() };
let client: channel_server::Client = capnp_rpc::new_client(server);
tokio::task::spawn_local(
RpcSystem::new(Box::new(network), Some(client.client))
);
}
#[allow(unreachable_code)]
Ok::<(), Box<dyn std::error::Error>>(())
})
.await
}