diff --git a/.gitignore b/.gitignore index aa085cd..06581e3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ Cargo.lock /target **/*.rs.bk +scratch/ diff --git a/Cargo.toml b/Cargo.toml index 715f9f5..91b14c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,12 @@ authors = ["Tony Garnock-Jones "] edition = "2018" [dependencies] -preserves = "0.1.0" +preserves = "0.1.3" + serde = { version = "1.0", features = ["derive"] } serde_bytes = "0.11" + +tokio = "0.2.0-alpha" +bytes = "0.4.12" + +futures-preview = { version = "=0.3.0-alpha.18", features = ["async-await", "nightly"] } diff --git a/rust-toolchain b/rust-toolchain new file mode 100644 index 0000000..bf867e0 --- /dev/null +++ b/rust-toolchain @@ -0,0 +1 @@ +nightly diff --git a/src/main.rs b/src/main.rs index 9275605..e73c4fc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,40 +1,225 @@ +#![recursion_limit="256"] + mod bag; mod skeleton; -use std::net::{TcpListener, TcpStream}; -use std::io::Result; +use bytes::BytesMut; use preserves::value; +use std::collections::BTreeMap; +use std::sync::Arc; +use tokio::prelude::*; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::mpsc::{channel, Sender, Receiver}; +use tokio::codec::{Framed, Encoder, Decoder}; +use futures::select; // use self::skeleton::Index; +type ConnId = u64; + mod packets { #[derive(Debug, serde::Serialize, serde::Deserialize)] pub struct Error(pub String); } -fn handle_connection(mut stream: TcpStream) -> Result<()> { - println!("Got {:?}", &stream); - let codec = value::Codec::without_placeholders(); - loop { - match codec.decode(&stream) { - Ok(v) => codec.encoder(&mut stream).write(&v)?, - Err(value::codec::Error::Eof) => break, - Err(value::codec::Error::Io(e)) => return Err(e), - Err(value::codec::Error::Syntax(s)) => { - let v = value::to_value(packets::Error(s.to_string())).unwrap(); - codec.encoder(&mut stream).write(&v)?; - break +#[derive(Debug)] +pub enum RelayMessage { + Hello(ConnId, Sender>), + Speak(ConnId, value::AValue), + Goodbye(ConnId), +} + +#[derive(Debug, Clone, serde::Serialize)] +pub enum PeerMessage { + Join(ConnId), + Speak(ConnId, value::AValue), + Leave(ConnId), +} + +struct ValueCodec { + codec: value::Codec, +} + +impl ValueCodec { + fn new(codec: value::Codec) -> Self { + ValueCodec { codec } + } +} + +impl Decoder for ValueCodec { + type Item = value::AValue; + type Error = value::decoder::Error; + fn decode(&mut self, bs: &mut BytesMut) -> Result, Self::Error> { + let mut buf = &bs[..]; + let orig_len = buf.len(); + let res = self.codec.decode(&mut buf); + let final_len = buf.len(); + bs.advance(orig_len - final_len); + match res { + Ok(v) => Ok(Some(v)), + Err(value::codec::Error::Eof) => Ok(None), + Err(e) => Err(e), + } + } +} + +impl Encoder for ValueCodec { + type Item = value::AValue; + type Error = value::encoder::Error; + fn encode(&mut self, item: Self::Item, bs: &mut BytesMut) -> Result<(), Self::Error> { + bs.extend(self.codec.encode_bytes(&item)?); + Ok(()) + } +} + +struct Peer { + id: ConnId, + rx: Receiver>, + relay: Sender, + frames: Framed, +} + +impl Peer { + async fn new(id: ConnId, mut relay: Sender, stream: TcpStream) -> Self { + let (tx, rx) = channel(1); + let frames = Framed::new(stream, ValueCodec::new(value::Codec::without_placeholders())); + relay.send(RelayMessage::Hello(id, tx)).await.unwrap(); + Peer{ id, rx, relay, frames } + } + + async fn run(&mut self) -> Result<(), std::io::Error> { + println!("Got {:?} {:?}", self.id, &self.frames.get_ref()); + let mut running = true; + while running { + let mut to_send = Vec::new(); + select! { + frame = self.frames.next().boxed().fuse() => match frame { + Some(res) => match res { + Ok(v) => { + if (v.value().as_symbol() == Some(&"die".to_string())) { + panic!(); + } else { + self.relay.send(RelayMessage::Speak(self.id, v)).await.unwrap() + } + } + Err(value::codec::Error::Eof) => running = false, + Err(value::codec::Error::Io(e)) => return Err(e), + Err(value::codec::Error::Syntax(s)) => { + let v = value::to_value(packets::Error(s.to_string())).unwrap(); + to_send.push(v); + running = false; + } + } + None => running = false, + }, + msgopt = self.rx.recv().boxed().fuse() => { + println!("MSGOPT {:?}", &msgopt); + match msgopt { + Some(msg) => to_send.push(value::to_value(&*msg).unwrap()), + None => /* weird. */ running = false, + } + }, + } + for v in to_send { self.frames.send(v).await?; } + } + Ok(()) + } +} + +impl Drop for Peer { + fn drop(&mut self) { + let mut relay = self.relay.clone(); + let id = self.id; + tokio::spawn(async move { + let _ = relay.send(RelayMessage::Goodbye(id)).await; + }); + } +} + +struct Relay { + rx: Receiver, + peers: BTreeMap>>, + pending: Vec, +} + +impl Relay { + fn new(rx: Receiver) -> Self { + Relay { rx, peers: BTreeMap::new(), pending: Vec::new() } + } + + async fn send(&mut self, i: ConnId, s: &mut Sender>, m: &Arc) + -> bool + { + match s.send(Arc::clone(m)).await { + Ok(_) => true, + Err(_) => { self.remove(i); false } + } + } + + fn remove(&mut self, i: ConnId) { + self.peers.remove(&i); + self.pending.push(PeerMessage::Leave(i)); + } + + async fn broadcast(&mut self, m: &Arc) { + for (i, ref mut s) in self.peers.clone() { + self.send(i, s, m).await; + } + } + + async fn run(&mut self) { + loop { + println!("Relay waiting for message ({} connected)", self.peers.len()); + let msg = self.rx.recv().await.unwrap(); + println!("Relay: {:?}", msg); + match msg { + RelayMessage::Hello(i, mut s) => { + let mut ok = true; + let i_join = &Arc::new(PeerMessage::Join(i)); + for (p, ref mut r) in self.peers.clone() { + ok = ok && self.send(i, &mut s, &Arc::new(PeerMessage::Join(p))).await; + self.send(p, r, i_join).await; + } + ok = ok && self.send(i, &mut s, i_join).await; + if ok { + self.peers.insert(i, s); + } + } + RelayMessage::Speak(i, v) => { + self.broadcast(&Arc::new(PeerMessage::Speak(i, v))).await; + } + RelayMessage::Goodbye(i) => self.remove(i), + } + while let Some(m) = self.pending.pop() { + self.broadcast(&Arc::new(m)).await; } } } - Ok(()) } -fn main() -> Result<()> { +#[tokio::main] +async fn main() -> Result<(), Box> { // let i = Index::new(); - let listener = TcpListener::bind("0.0.0.0:5889")?; - for stream in listener.incoming() { - handle_connection(stream?); + + // Unlike std channels, a zero buffer is not supported + let (tx, rx) = channel(100); // but ugh a big buffer is needed to avoid deadlocks??? + tokio::spawn(async { + Relay::new(rx).run().await; + }); + + let mut id = 0; + + let mut listener = TcpListener::bind("0.0.0.0:5889").await?; + loop { + let (stream, addr) = listener.accept().await?; + let tx = tx.clone(); + let connid = id; + id = id + 1; + tokio::spawn(async move { + match Peer::new(connid, tx, stream).await.run().await { + Ok(_) => (), + Err(e) => println!("Connection {:?} died with {:?}", addr, e), + } + }); } - Ok(()) }