From 46d6d80b42a16e306366ea70169ac3818dfe2474 Mon Sep 17 00:00:00 2001 From: Tony Garnock-Jones Date: Mon, 9 Aug 2021 09:19:00 -0400 Subject: [PATCH] Unix socket listener --- src/bin/syndicate-server.rs | 160 +++++++++++++++++++++++++++--------- src/config.rs | 5 ++ 2 files changed, 127 insertions(+), 38 deletions(-) diff --git a/src/bin/syndicate-server.rs b/src/bin/syndicate-server.rs index 2c6e1e3..21e5a91 100644 --- a/src/bin/syndicate-server.rs +++ b/src/bin/syndicate-server.rs @@ -6,7 +6,9 @@ use preserves::value::NestedValue; use preserves::value::Value; use std::future::ready; +use std::io; use std::iter::FromIterator; +use std::path::PathBuf; use std::sync::Arc; use structopt::StructOpt; // for from_args in main @@ -24,6 +26,8 @@ use syndicate::sturdy; use tokio::net::TcpListener; use tokio::net::TcpStream; +use tokio::net::UnixListener; +use tokio::net::UnixStream; use tungstenite::Message; @@ -32,6 +36,8 @@ async fn main() -> Result<(), Box> { syndicate::convenient_logging()?; syndicate::actor::start_debt_reporter(); + let config = Arc::new(config::ServerConfig::from_args()); + { const BRIGHT_GREEN: &str = "\x1b[92m"; const RED: &str = "\x1b[31m"; @@ -65,8 +71,6 @@ async fn main() -> Result<(), Box> { tracing::info!(r""); } - let config = Arc::new(config::ServerConfig::from_args()); - let mut daemons = Vec::new(); tracing::trace!("startup"); @@ -94,7 +98,16 @@ async fn main() -> Result<(), Box> { daemons.push(Actor::new().boot( syndicate::name!("tcp", port), move |t| Ok(t.state.linked_task(syndicate::name!("listener"), - run_listener(gateway, port, config))))); + run_tcp_listener(gateway, port, config))))); + } + + for path in config.sockets.clone() { + let gateway = Arc::clone(&gateway); + let config = Arc::clone(&config); + daemons.push(Actor::new().boot( + syndicate::name!("unix", socket = debug(path.to_str().expect("representable UnixListener path"))), + move |t| Ok(t.state.linked_task(syndicate::name!("listener"), + run_unix_listener(gateway, path, config))))); } futures::future::join_all(daemons).await; @@ -127,43 +140,23 @@ fn extract_binary_packets( } } -async fn run_connection( +struct ExitListener; + +impl Entity<()> for ExitListener { + fn exit_hook(&mut self, _t: &mut Activation, exit_status: &Arc) -> ActorResult { + tracing::info!(exit_status = debug(exit_status), "disconnect"); + Ok(()) + } +} + +fn run_connection( ac: ActorRef, - stream: TcpStream, + i: relay::Input, + o: relay::Output, gateway: Arc, - addr: std::net::SocketAddr, config: Arc, ) -> ActorResult { - let mut buf = [0; 1]; // peek at the first byte to see what kind of connection to expect - let (i, o) = match stream.peek(&mut buf).await? { - 1 => match buf[0] { - b'G' /* ASCII 'G' for "GET" */ => { - tracing::info!(protocol = display("websocket"), peer = debug(addr)); - let s = tokio_tungstenite::accept_async(stream).await - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; - let (o, i) = s.split(); - let i = i.filter_map(|r| ready(extract_binary_packets(r).transpose())); - let o = o.sink_map_err(message_error).with(|bs| ready(Ok(Message::Binary(bs)))); - (relay::Input::Packets(Box::pin(i)), relay::Output::Packets(Box::pin(o))) - }, - _ => { - tracing::info!(protocol = display("raw"), peer = debug(addr)); - let (i, o) = stream.into_split(); - (relay::Input::Bytes(Box::pin(i)), - relay::Output::Bytes(Box::pin(o /* BufWriter::new(o) */))) - } - } - 0 => Err(error("closed before starting", _Any::new(false)))?, - _ => unreachable!() - }; Activation::for_actor(&ac, Debtor::new(syndicate::name!("start-session")), |t| { - struct ExitListener; - impl Entity<()> for ExitListener { - fn exit_hook(&mut self, _t: &mut Activation, exit_status: &Arc) -> ActorResult { - tracing::info!(exit_status = debug(exit_status), "disconnect"); - Ok(()) - } - } let exit_listener = t.state.create(ExitListener); t.state.add_exit_hook(&exit_listener); relay::TunnelRelay::run(t, i, o, Some(gateway), None); @@ -171,7 +164,41 @@ async fn run_connection( }) } -async fn run_listener( +async fn detect_protocol( + ac: ActorRef, + stream: TcpStream, + gateway: Arc, + addr: std::net::SocketAddr, + config: Arc, +) -> ActorResult { + let (i, o) = { + let mut buf = [0; 1]; // peek at the first byte to see what kind of connection to expect + match stream.peek(&mut buf).await? { + 1 => match buf[0] { + b'G' /* ASCII 'G' for "GET" */ => { + tracing::info!(protocol = display("websocket"), peer = debug(addr)); + let s = tokio_tungstenite::accept_async(stream).await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + let (o, i) = s.split(); + let i = i.filter_map(|r| ready(extract_binary_packets(r).transpose())); + let o = o.sink_map_err(message_error).with(|bs| ready(Ok(Message::Binary(bs)))); + (relay::Input::Packets(Box::pin(i)), relay::Output::Packets(Box::pin(o))) + }, + _ => { + tracing::info!(protocol = display("raw"), peer = debug(addr)); + let (i, o) = stream.into_split(); + (relay::Input::Bytes(Box::pin(i)), + relay::Output::Bytes(Box::pin(o /* BufWriter::new(o) */))) + } + } + 0 => Err(error("closed before starting", _Any::new(false)))?, + _ => unreachable!() + } + }; + run_connection(ac, i, o, gateway, config) +} + +async fn run_tcp_listener( gateway: Arc, port: u16, config: Arc, @@ -184,10 +211,67 @@ async fn run_listener( let gateway = Arc::clone(&gateway); let config = Arc::clone(&config); let ac = Actor::new(); - ac.boot(syndicate::name!(parent: None, "connection"), + ac.boot(syndicate::name!(parent: None, "tcp"), move |t| Ok(t.state.linked_task( tracing::Span::current(), - run_connection(t.actor.clone(), stream, gateway, addr, config)))); + detect_protocol(t.actor.clone(), stream, gateway, addr, config)))); + } +} + +async fn run_unix_listener( + gateway: Arc, + path: PathBuf, + config: Arc, +) -> ActorResult { + let path_str = path.to_str().expect("representable UnixListener path"); + tracing::info!("Listening on {:?}", path_str); + let listener = bind_unix_listener(&path).await?; + loop { + let (stream, addr) = listener.accept().await?; + let gateway = Arc::clone(&gateway); + let config = Arc::clone(&config); + let ac = Actor::new(); + ac.boot(syndicate::name!(parent: None, "unix"), + move |t| Ok(t.state.linked_task( + tracing::Span::current(), + { + let ac = t.actor.clone(); + async move { + tracing::info!(protocol = display("unix"), peer = debug(addr)); + let (i, o) = stream.into_split(); + run_connection(ac, + relay::Input::Bytes(Box::pin(i)), + relay::Output::Bytes(Box::pin(o)), + gateway, + config) + } + }))); + } +} + +async fn bind_unix_listener(path: &PathBuf) -> Result { + match UnixListener::bind(path) { + Ok(s) => Ok(s), + Err(e) if e.kind() == io::ErrorKind::AddrInUse => { + // Potentially-stale socket file sitting around. Try + // connecting to it to see if it is alive, and remove it + // if not. + match UnixStream::connect(path).await { + Ok(_probe) => Err(e)?, // Someone's already there! Give up. + Err(f) if f.kind() == io::ErrorKind::ConnectionRefused => { + // Try to steal the socket. + tracing::info!("Cleaning stale socket"); + std::fs::remove_file(path)?; + Ok(UnixListener::bind(path)?) + } + Err(f) => { + tracing::error!(error = debug(f), + "Problem while probing potentially-stale socket"); + return Err(e)? // signal the *original* error, not the probe error + } + } + }, + Err(e) => Err(e)?, } } diff --git a/src/config.rs b/src/config.rs index b76f6d9..69324b5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,10 +1,15 @@ use structopt::StructOpt; +use std::path::PathBuf; + #[derive(Clone, StructOpt)] pub struct ServerConfig { #[structopt(short = "p", long = "port", default_value = "8001")] pub ports: Vec, + #[structopt(short = "s", long = "socket")] + pub sockets: Vec, + #[structopt(long, default_value = "10000")] pub overload_threshold: usize, #[structopt(long, default_value = "5")]