use syndicate::{config, spaces, packets, ConnId, V, Syndicate}; use syndicate::peer::{Peer, ResultC2S}; use preserves::value; use std::sync::{Mutex, Arc}; use futures::{SinkExt, StreamExt}; use tracing::{Level, error, info, trace}; use tracing_futures::Instrument; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_util::codec::Framed; use tungstenite::Message; use structopt::StructOpt; // for from_args in main type UnitAsyncResult = Result<(), std::io::Error>; fn other_eio(e: E) -> std::io::Error { std::io::Error::new(std::io::ErrorKind::Other, e.to_string()) } fn translate_sink_err(e: tungstenite::Error) -> packets::EncodeError { packets::EncodeError::Write(other_eio(e)) } fn encode_message(codec: &value::Codec, p: packets::S2C) -> Result { let v: V = value::to_value(p)?; Ok(Message::Binary(codec.encode_bytes(&v)?)) } fn message_encoder(codec: &value::Codec) -> impl Fn(packets::S2C) -> futures::future::Ready> + '_ { return move |p| futures::future::ready(encode_message(codec, p)); } fn message_decoder(codec: &value::Codec) -> impl Fn(Result) -> ResultC2S + '_ { return move |r| { loop { return match r { Ok(ref m) => match m { Message::Text(_) => Err(packets::DecodeError::Read( value::decoder::Error::Syntax("Text websocket frames are not accepted"))), Message::Binary(ref bs) => { let mut buf = &bs[..]; match codec.decode(&mut buf) { Ok(v) => if buf.len() > 0 { Err(packets::DecodeError::Read( value::decoder::Error::Io( std::io::Error::new(std::io::ErrorKind::Other, format!("{} trailing bytes", buf.len()))))) } else { value::from_value(&v).map_err(|e| packets::DecodeError::Parse(e, v)) } Err(value::decoder::Error::Eof) => Err(packets::DecodeError::Read( value::decoder::Error::Io( std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "short packet")))), Err(e) => Err(e.into()), } } Message::Ping(_) => continue, // pings are handled by tungstenite before we see them Message::Pong(_) => continue, // unsolicited pongs are to be ignored Message::Close(_) => Err(packets::DecodeError::Read(value::decoder::Error::Eof)), } Err(tungstenite::Error::Io(e)) => Err(e.into()), Err(e) => Err(packets::DecodeError::Read(value::decoder::Error::Io(other_eio(e)))), } } }; } async fn run_connection(connid: ConnId, mut stream: TcpStream, spaces: Arc>, addr: std::net::SocketAddr, config: config::ServerConfigRef) -> UnitAsyncResult { let mut buf = [0; 1]; // peek at the first byte to see what kind of connection to expect match stream.peek(&mut buf).await? { 1 => match buf[0] { 71 /* ASCII 'G' for "GET" */ => { info!(protocol = display("websocket"), peer = debug(addr)); let s = tokio_tungstenite::accept_async(stream).await .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; let (o, i) = s.split(); let codec = packets::standard_preserves_codec(); let i = i.map(message_decoder(&codec)); let o = o.sink_map_err(translate_sink_err).with(message_encoder(&codec)); let mut p = Peer::new(connid, i, o); p.run(spaces, &config).await? }, _ => { info!(protocol = display("raw"), peer = debug(addr)); let (o, i) = Framed::new(stream, packets::Codec::standard()).split(); let mut p = Peer::new(connid, i, o); p.run(spaces, &config).await? } } 0 => return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "closed before starting")), _ => unreachable!() } Ok(()) } static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1); async fn run_listener(spaces: Arc>, port: u16, config: config::ServerConfigRef) -> UnitAsyncResult { let mut listener = TcpListener::bind(format!("0.0.0.0:{}", port)).await?; loop { let (stream, addr) = listener.accept().await?; let id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed); let spaces = Arc::clone(&spaces); let config = Arc::clone(&config); if let Some(n) = config.recv_buffer_size { stream.set_recv_buffer_size(n)?; } if let Some(n) = config.send_buffer_size { stream.set_send_buffer_size(n)?; } tokio::spawn(async move { match run_connection(id, stream, spaces, addr, config).await { Ok(()) => info!("closed"), Err(e) => info!(error = display(e), "closed"), } }.instrument(tracing::info_span!("connection", id))); } } async fn periodic_tasks(spaces: Arc>) -> UnitAsyncResult { let interval = core::time::Duration::from_secs(10); let mut delay = tokio::time::interval(interval); loop { delay.next().await.unwrap(); { let mut spaces = spaces.lock().unwrap(); spaces.cleanup(); spaces.dump_stats(interval); } } } #[tokio::main] async fn main() -> Result<(), Box> { let filter = tracing_subscriber::filter::EnvFilter::from_default_env() .add_directive(tracing_subscriber::filter::LevelFilter::INFO.into()); let subscriber = tracing_subscriber::FmtSubscriber::builder() .with_ansi(true) .with_max_level(Level::TRACE) .with_env_filter(filter) .finish(); tracing::subscriber::set_global_default(subscriber) .expect("Could not set tracing global subscriber"); let config = Arc::new(config::ServerConfig::from_args()); let spaces = Arc::new(Mutex::new(spaces::Spaces::new())); let mut daemons = Vec::new(); { let spaces = Arc::clone(&spaces); tokio::spawn(async move { periodic_tasks(spaces).await }); } trace!("startup"); for port in config.ports.clone() { let spaces = Arc::clone(&spaces); let config = Arc::clone(&config); daemons.push(tokio::spawn(async move { info!(port, "listening"); match run_listener(spaces, port, config).await { Ok(()) => (), Err(e) => { error!("{}", e); std::process::exit(2) } } }.instrument(tracing::info_span!("listener", port)))); } futures::future::join_all(daemons).await; Ok(()) }