diff --git a/src/bin/syndicate-server.rs b/src/bin/syndicate-server.rs index 1fdcc86..f44441a 100644 --- a/src/bin/syndicate-server.rs +++ b/src/bin/syndicate-server.rs @@ -52,33 +52,29 @@ fn message_decoder(codec: &value::Codec) return match r { Ok(ref m) => match m { Message::Text(_) => Err(packets::DecodeError::Read( - value::decoder::Error::Syntax("Text websocket frames are not accepted"))), + value::reader::err("Text websocket frames are not accepted"))), Message::Binary(ref bs) => { let mut buf = &bs[..]; - match codec.decode(&mut buf) { - Ok(v) => if buf.len() > 0 { - Err(packets::DecodeError::Read( - value::decoder::Error::Io( - std::io::Error::new(std::io::ErrorKind::Other, - format!("{} trailing bytes", - buf.len()))))) - } else { - value::from_value(&v).map_err(|e| packets::DecodeError::Parse(e, v)) - } - Err(value::decoder::Error::Eof) => - Err(packets::DecodeError::Read( - value::decoder::Error::Io( - std::io::Error::new(std::io::ErrorKind::UnexpectedEof, - "short packet")))), - Err(e) => Err(e.into()), + let mut vs = codec.decode_all(&mut buf)?; + if vs.len() > 1 { + Err(packets::DecodeError::Read( + std::io::Error::new(std::io::ErrorKind::Other, + "Multiple packets in a single message"))) + } else if vs.len() == 0 { + Err(packets::DecodeError::Read( + std::io::Error::new(std::io::ErrorKind::Other, + "Empty message"))) + } else { + value::from_value(&vs[0]) + .map_err(|e| packets::DecodeError::Parse(e, vs.swap_remove(0))) } } Message::Ping(_) => continue, // pings are handled by tungstenite before we see them Message::Pong(_) => continue, // unsolicited pongs are to be ignored - Message::Close(_) => Err(packets::DecodeError::Read(value::decoder::Error::Eof)), + Message::Close(_) => Err(packets::DecodeError::Read(value::reader::eof())), } Err(tungstenite::Error::Io(e)) => Err(e.into()), - Err(e) => Err(packets::DecodeError::Read(value::decoder::Error::Io(other_eio(e)))), + Err(e) => Err(packets::DecodeError::Read(other_eio(e))), } } }; diff --git a/src/packets.rs b/src/packets.rs index 33165b8..e9d2aa5 100644 --- a/src/packets.rs +++ b/src/packets.rs @@ -2,7 +2,7 @@ use super::V; use super::Syndicate; use bytes::{Buf, buf::BufMutExt, BytesMut}; -use preserves::{value, ser::Serializer}; +use preserves::{value, ser::Serializer, value::Reader}; use std::io; use std::sync::Arc; use std::marker::PhantomData; @@ -50,12 +50,6 @@ pub enum DecodeError { Parse(value::error::Error, V), } -impl From for DecodeError { - fn from(v: value::decoder::Error) -> Self { - DecodeError::Read(v) - } -} - impl From for DecodeError { fn from(v: io::Error) -> Self { DecodeError::Read(v.into()) @@ -109,6 +103,7 @@ impl std::fmt::Display for EncodeError { impl std::error::Error for EncodeError { } + //--------------------------------------------------------------------------- pub struct Codec { @@ -148,18 +143,19 @@ impl tokio_util::codec::Decoder for Code fn decode(&mut self, bs: &mut BytesMut) -> Result, Self::Error> { let mut buf = &bs[..]; let orig_len = buf.len(); - let res = self.codec.decode(&mut buf); - let final_len = buf.len(); - match res { - Ok(v) => { - bs.advance(orig_len - final_len); + let mut d = self.codec.decoder(&mut buf); + match d.next() { + None => Ok(None), + Some(res) => { + let v = res?; + let buffered_len = d.read.buffered_len()?; + let final_len = buf.len(); + bs.advance(orig_len - final_len - buffered_len); match value::from_value(&v) { Ok(p) => Ok(Some(p)), Err(e) => Err(DecodeError::Parse(e, v)) } } - Err(value::decoder::Error::Eof) => Ok(None), - Err(e) => Err(DecodeError::Read(e)), } } } diff --git a/src/peer.rs b/src/peer.rs index c6d602c..de1595a 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -121,14 +121,16 @@ where I: Stream + Send, } } } - Err(packets::DecodeError::Read(value::decoder::Error::Eof)) => { - tracing::trace!("eof"); - running = false; - } - Err(packets::DecodeError::Read(value::decoder::Error::Io(e))) => return Err(e), - Err(packets::DecodeError::Read(value::decoder::Error::Syntax(s))) => { - to_send.push(err(s, value::Value::from(false).wrap())); - running = false; + Err(packets::DecodeError::Read(e)) => { + if value::is_eof_error(&e) { + tracing::trace!("eof"); + running = false; + } else if value::is_syntax_error(&e) { + to_send.push(err(&e.to_string(), value::Value::from(false).wrap())); + running = false; + } else { + return Err(e) + } } Err(packets::DecodeError::Parse(e, v)) => { to_send.push(err(&format!("Packet deserialization error: {}", e), v));