From 677586827267820950a43a38bef40057545b5640 Mon Sep 17 00:00:00 2001 From: Tony Garnock-Jones Date: Tue, 17 Sep 2019 09:43:03 +0100 Subject: [PATCH] Better errors; Symbol wrapper; proper identifier decoding; Test case struct deserialization works --- implementations/rust/src/lib.rs | 7 ++- implementations/rust/src/symbol.rs | 18 ++++++ implementations/rust/src/value/de.rs | 77 ++++++++++++++----------- implementations/rust/src/value/error.rs | 23 +++++++- implementations/rust/src/value/ser.rs | 4 +- 5 files changed, 90 insertions(+), 39 deletions(-) create mode 100644 implementations/rust/src/symbol.rs diff --git a/implementations/rust/src/lib.rs b/implementations/rust/src/lib.rs index f2aadb7..f0149b3 100644 --- a/implementations/rust/src/lib.rs +++ b/implementations/rust/src/lib.rs @@ -1,4 +1,5 @@ pub mod value; +pub mod symbol; #[cfg(test)] mod ieee754_section_5_10_total_order_tests { @@ -205,6 +206,7 @@ mod decoder_tests { #[cfg(test)] mod samples_tests { + use crate::symbol::Symbol; use crate::value::Decoder; use crate::value::{Value, AValue}; use crate::value::DecodePlaceholderMap; @@ -218,8 +220,7 @@ mod samples_tests { #[derive(Debug, serde::Serialize, serde::Deserialize)] struct TestCases { decode_placeholders: ExpectedPlaceholderMapping, - tests: BTreeMap + tests: BTreeMap } #[derive(Debug, serde::Serialize, serde::Deserialize)] @@ -250,12 +251,14 @@ mod samples_tests { #[test] fn simple_to_value() { #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] struct SimpleValue<'a>(String, + Symbol, &'a str, #[serde(with = "serde_bytes")] &'a [u8], #[serde(with = "serde_bytes")] Vec, i16, AValue); let v = SimpleValue("hello".to_string(), + Symbol("sym".to_string()), "world", &b"slice"[..], b"vec".to_vec(), diff --git a/implementations/rust/src/symbol.rs b/implementations/rust/src/symbol.rs new file mode 100644 index 0000000..dbe8e4f --- /dev/null +++ b/implementations/rust/src/symbol.rs @@ -0,0 +1,18 @@ +use crate::value::Value; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct Symbol(pub String); + +impl serde::Serialize for Symbol { + fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { + Value::symbol(&self.0).serialize(serializer) + } +} + +impl<'de> serde::Deserialize<'de> for Symbol { + fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de> { + let v = Value::deserialize(deserializer)?; + let s = v.as_symbol().ok_or(serde::de::Error::custom("Expected symbol"))?; + Ok(Symbol(s.clone())) + } +} diff --git a/implementations/rust/src/value/de.rs b/implementations/rust/src/value/de.rs index 91dd2a7..f7c5240 100644 --- a/implementations/rust/src/value/de.rs +++ b/implementations/rust/src/value/de.rs @@ -1,6 +1,6 @@ use crate::value::{Value, AValue, Dictionary}; use crate::value::value::{Float, Double}; -use crate::value::error::{Error, Result}; +use crate::value::error::{Error, Result, ExpectedKind}; use num::traits::cast::ToPrimitive; use serde::Deserialize; use serde::de::{Visitor, SeqAccess, MapAccess, EnumAccess, VariantAccess, DeserializeSeed}; @@ -21,6 +21,12 @@ impl<'de> Deserializer<'de> { pub fn from_value(v: &'de AValue) -> Self { Deserializer{ input: v } } + + fn check<'a, T, F>(&'a mut self, f: F, k: ExpectedKind) -> Result + where F: FnOnce(&'de Value) -> Option + { + f(self.input.value()).ok_or_else(|| Error::Expected(k, self.input.clone())) + } } impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { @@ -56,89 +62,89 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { } else if v.is_simple_record("tuple", None) { visitor.visit_seq(VecSeq::new(self, v.as_simple_record("tuple", None).unwrap())) } else { - Err(Error::Syntax) + Err(Error::CannotDeserializeAny) } Value::Sequence(ref v) => visitor.visit_seq(VecSeq::new(self, v)), Value::Dictionary(ref d) => visitor.visit_map(DictMap::new(self, d)), - _ => Err(Error::Syntax), + _ => Err(Error::CannotDeserializeAny), } } fn deserialize_bool(self, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_bool(self.input.value().as_boolean().ok_or(Error::Syntax)?) + visitor.visit_bool(self.check(|v| v.as_boolean(), ExpectedKind::Boolean)?) } fn deserialize_i8(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.input.value().as_signedinteger().ok_or(Error::Syntax)?; + let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; visitor.visit_i8(i.to_i8().ok_or(Error::NumberTooLarge(i.clone()))?) } fn deserialize_i16(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.input.value().as_signedinteger().ok_or(Error::Syntax)?; + let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; visitor.visit_i16(i.to_i16().ok_or(Error::NumberTooLarge(i.clone()))?) } fn deserialize_i32(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.input.value().as_signedinteger().ok_or(Error::Syntax)?; + let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; visitor.visit_i32(i.to_i32().ok_or(Error::NumberTooLarge(i.clone()))?) } fn deserialize_i64(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.input.value().as_signedinteger().ok_or(Error::Syntax)?; + let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; visitor.visit_i64(i.to_i64().ok_or(Error::NumberTooLarge(i.clone()))?) } fn deserialize_u8(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.input.value().as_signedinteger().ok_or(Error::Syntax)?; + let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; visitor.visit_u8(i.to_u8().ok_or(Error::NumberTooLarge(i.clone()))?) } fn deserialize_u16(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.input.value().as_signedinteger().ok_or(Error::Syntax)?; + let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; visitor.visit_u16(i.to_u16().ok_or(Error::NumberTooLarge(i.clone()))?) } fn deserialize_u32(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.input.value().as_signedinteger().ok_or(Error::Syntax)?; + let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; visitor.visit_u32(i.to_u32().ok_or(Error::NumberTooLarge(i.clone()))?) } fn deserialize_u64(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.input.value().as_signedinteger().ok_or(Error::Syntax)?; + let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; visitor.visit_u64(i.to_u64().ok_or(Error::NumberTooLarge(i.clone()))?) } fn deserialize_f32(self, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_f32(self.input.value().as_float().ok_or(Error::Syntax)?) + visitor.visit_f32(self.check(|v| v.as_float(), ExpectedKind::Float)?) } fn deserialize_f64(self, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_f64(self.input.value().as_double().ok_or(Error::Syntax)?) + visitor.visit_f64(self.check(|v| v.as_double(), ExpectedKind::Double)?) } fn deserialize_char(self, visitor: V) -> Result where V: Visitor<'de> { - let fs = - self.input.value().as_simple_record("UnicodeScalar", Some(1)).ok_or(Error::Syntax)?; - if fs.len() != 1 { return Err(Error::Syntax) } - let c = fs[0].value().as_u32().ok_or(Error::Syntax)?; + let fs = self.check(|v| v.as_simple_record("UnicodeScalar", Some(1)), + ExpectedKind::SimpleRecord("UnicodeScalar", Some(1)))?; + let c = fs[0].value().as_u32() + .ok_or_else(|| Error::Expected(ExpectedKind::SignedInteger, self.input.clone()))?; visitor.visit_char(char::try_from(c).or(Err(Error::InvalidUnicodeScalar(c)))?) } fn deserialize_str(self, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_borrowed_str(&self.input.value().as_string().ok_or(Error::Syntax)?) + visitor.visit_borrowed_str(&self.check(|v| v.as_string(), ExpectedKind::String)?) } fn deserialize_string(self, visitor: V) -> Result where V: Visitor<'de> @@ -148,12 +154,12 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { fn deserialize_bytes(self, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_borrowed_bytes(&self.input.value().as_bytestring().ok_or(Error::Syntax)?) + visitor.visit_borrowed_bytes(&self.check(|v| v.as_bytestring(), ExpectedKind::ByteString)?) } fn deserialize_byte_buf(self, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_byte_buf(self.input.value().as_bytestring().ok_or(Error::Syntax)?.clone()) + visitor.visit_byte_buf(self.check(|v| v.as_bytestring(), ExpectedKind::ByteString)?.clone()) } fn deserialize_option(self, visitor: V) -> Result where V: Visitor<'de> @@ -165,7 +171,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { self.input = &fs[0]; visitor.visit_some(self) } - None => Err(Error::Syntax), + None => Err(Error::Expected(ExpectedKind::Option, self.input.clone())) } } } @@ -175,7 +181,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { if self.input.value().is_simple_record("tuple", Some(0)) { visitor.visit_unit() } else { - Err(Error::Syntax) + Err(Error::Expected(ExpectedKind::SimpleRecord("tuple", Some(0)), self.input.clone())) } } @@ -185,7 +191,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { if self.input.value().is_simple_record(name, Some(0)) { visitor.visit_unit() } else { - Err(Error::Syntax) + Err(Error::Expected(ExpectedKind::SimpleRecord(name, Some(0)), self.input.clone())) } } @@ -195,35 +201,38 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { if name == crate::value::value::MAGIC { let mut buf: Vec = Vec::new(); crate::value::Encoder::new(&mut buf, None).write(self.input) - .or_else(|_e| Err(Error::Message("Internal error".to_string())))?; + .or_else(|_e| Err(Error::InternalMagicError))?; visitor.visit_byte_buf(buf) } else { - let fs = self.input.value().as_simple_record(name, Some(1)).ok_or(Error::Syntax)?; + let fs = self.check(|v| v.as_simple_record(name, Some(1)), + ExpectedKind::SimpleRecord(name, Some(1)))?; self.input = &fs[0]; visitor.visit_newtype_struct(self) } } fn deserialize_seq(self, visitor: V) -> Result where V: Visitor<'de> { - let vs = self.input.value().as_sequence().ok_or(Error::Syntax)?; + let vs = self.check(|v| v.as_sequence(), ExpectedKind::Sequence)?; visitor.visit_seq(VecSeq::new(self, vs)) } fn deserialize_tuple(self, len: usize, visitor: V) -> Result where V: Visitor<'de> { - let fs = self.input.value().as_simple_record("tuple", Some(len)).ok_or(Error::Syntax)?; + let fs = self.check(|v| v.as_simple_record("tuple", Some(len)), + ExpectedKind::SimpleRecord("tuple", Some(len)))?; visitor.visit_seq(VecSeq::new(self, fs)) } fn deserialize_tuple_struct(self, name: &'static str, len: usize, visitor: V) -> Result where V: Visitor<'de> { - let fs = self.input.value().as_simple_record(name, Some(len)).ok_or(Error::Syntax)?; + let fs = self.check(|v| v.as_simple_record(name, Some(len)), + ExpectedKind::SimpleRecord(name, Some(len)))?; visitor.visit_seq(VecSeq::new(self, fs)) } fn deserialize_map(self, visitor: V) -> Result where V: Visitor<'de> { - let d = self.input.value().as_dictionary().ok_or(Error::Syntax)?; + let d = self.check(|v| v.as_dictionary(), ExpectedKind::Dictionary)?; visitor.visit_map(DictMap::new(self, d)) } @@ -233,8 +242,8 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { visitor: V) -> Result where V: Visitor<'de> { - let fs = - self.input.value().as_simple_record(name, Some(fields.len())).ok_or(Error::Syntax)?; + let fs = self.check(|v| v.as_simple_record(name, Some(fields.len())), + ExpectedKind::SimpleRecord(name, Some(fields.len())))?; visitor.visit_seq(VecSeq::new(self, fs)) } @@ -249,7 +258,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { fn deserialize_identifier(self, visitor: V) -> Result where V: Visitor<'de> { - self.deserialize_str(visitor) + visitor.visit_borrowed_str(&self.check(|v| v.as_symbol(), ExpectedKind::Symbol)?) } fn deserialize_ignored_any(self, visitor: V) -> Result where V: Visitor<'de> @@ -330,8 +339,8 @@ impl<'a, 'de> EnumAccess<'de> for &'a mut Deserializer<'de> { fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> where V: DeserializeSeed<'de> { + let (lp, _) = self.check(|v| v.as_record(), ExpectedKind::Record)?; let v = self.input; - let (lp, _) = v.value().as_record().ok_or(Error::Syntax)?; self.input = lp; let variant = seed.deserialize(&mut *self)?; self.input = v; diff --git a/implementations/rust/src/value/error.rs b/implementations/rust/src/value/error.rs index df52c59..42df502 100644 --- a/implementations/rust/src/value/error.rs +++ b/implementations/rust/src/value/error.rs @@ -1,11 +1,32 @@ use num::bigint::BigInt; +use crate::value::AValue; #[derive(Debug)] pub enum Error { Message(String), - Syntax, InvalidUnicodeScalar(u32), NumberTooLarge(BigInt), + CannotDeserializeAny, + Expected(ExpectedKind, AValue), + InternalMagicError, +} + +#[derive(Debug)] +pub enum ExpectedKind { + Boolean, + Float, + Double, + + SignedInteger, + String, + ByteString, + Symbol, + + Record, + SimpleRecord(&'static str, Option), + Option, + Sequence, + Dictionary, } pub type Result = std::result::Result; diff --git a/implementations/rust/src/value/ser.rs b/implementations/rust/src/value/ser.rs index d3a921b..a72b39c 100644 --- a/implementations/rust/src/value/ser.rs +++ b/implementations/rust/src/value/ser.rs @@ -115,8 +115,8 @@ impl serde::Serializer for Serializer { { if name == crate::value::value::MAGIC { let v = to_value(value)?; - let buf: &[u8] = v.value().as_bytestring().ok_or(Error::Syntax)?; - crate::value::Decoder::new(buf, None).next().or(Err(Error::Syntax)) + let buf: &[u8] = v.value().as_bytestring().ok_or(Error::InternalMagicError)?; + crate::value::Decoder::new(buf, None).next().or(Err(Error::InternalMagicError)) } else { // TODO: This is apparently discouraged, and we should apparently just serialize `value`? Ok(Value::simple_record(name, vec![to_value(value)?]).wrap())