From b5f4c3a498a6b843e091b3c8e891207fea4bc030 Mon Sep 17 00:00:00 2001 From: Tony Garnock-Jones Date: Wed, 27 May 2020 09:04:55 +0200 Subject: [PATCH] Remove placeholders from spec and implementations 2/5 Enormous refactoring in Rust implementation. Direct deserialization. Zero-copy deserialization in some cases. Much faster. --- implementations/rust/Cargo.toml | 1 + implementations/rust/Makefile | 4 + implementations/rust/src/de.rs | 338 +++++++ implementations/rust/src/error.rs | 129 ++- implementations/rust/src/lib.rs | 379 ++++++-- implementations/rust/src/ser.rs | 44 +- implementations/rust/src/set.rs | 24 + implementations/rust/src/symbol.rs | 6 +- implementations/rust/src/value/codec.rs | 49 - implementations/rust/src/value/constants.rs | 6 + implementations/rust/src/value/de.rs | 204 ++-- implementations/rust/src/value/decoder.rs | 47 +- implementations/rust/src/value/encoder.rs | 80 +- implementations/rust/src/value/magic.rs | 21 +- implementations/rust/src/value/mod.rs | 10 +- implementations/rust/src/value/reader.rs | 989 +++++++++++++++----- implementations/rust/src/value/value.rs | 155 ++- implementations/rust/src/value/writer.rs | 5 - 18 files changed, 1842 insertions(+), 649 deletions(-) create mode 100644 implementations/rust/src/set.rs delete mode 100644 implementations/rust/src/value/codec.rs diff --git a/implementations/rust/Cargo.toml b/implementations/rust/Cargo.toml index 5c45cbe..a678e15 100644 --- a/implementations/rust/Cargo.toml +++ b/implementations/rust/Cargo.toml @@ -16,3 +16,4 @@ num = "0.2" num_enum = "0.4.1" serde = { version = "1.0", features = ["derive"] } serde_bytes = "0.11" +lazy_static = "1.4.0" diff --git a/implementations/rust/Makefile b/implementations/rust/Makefile index 32d645b..a87f31d 100644 --- a/implementations/rust/Makefile +++ b/implementations/rust/Makefile @@ -7,3 +7,7 @@ clippy-watch: inotifytest: inotifytest sh -c 'reset; cargo build && RUST_BACKTRACE=1 cargo test -- --nocapture' + +debug-tests: + cargo test --no-run + gdb --args $$(cargo test 3>&1 1>&2 2>&3 3>&- | grep Running | awk '{print $$2}') --test-threads=1 diff --git a/implementations/rust/src/de.rs b/implementations/rust/src/de.rs index e69de29..8d50aff 100644 --- a/implementations/rust/src/de.rs +++ b/implementations/rust/src/de.rs @@ -0,0 +1,338 @@ +use serde::Deserialize; +use serde::de::{Visitor, SeqAccess, MapAccess, EnumAccess, VariantAccess, DeserializeSeed}; +use std::borrow::Cow; +use std::marker::PhantomData; +use super::value::reader::{Reader, BinaryReader, IOBinarySource, BytesBinarySource, CompoundBody}; + +pub use super::error::Error; + +pub type Result = std::result::Result; + +pub struct Deserializer<'de, 'r, R: Reader<'de>> { + pub read: &'r mut R, + phantom: PhantomData<&'de ()>, +} + +pub fn from_bytes<'de, T>(bytes: &'de [u8]) -> + Result +where + T: Deserialize<'de> +{ + from_reader(&mut BinaryReader::new(BytesBinarySource::new(bytes))) +} + +pub fn from_read<'de, 'r, IOR: std::io::Read, T>(read: &'r mut IOR) -> + Result +where + T: Deserialize<'de> +{ + from_reader(&mut BinaryReader::new(IOBinarySource::new(read))) +} + +pub fn from_reader<'r, 'de, R: Reader<'de>, T>(read: &'r mut R) -> + Result +where + T: Deserialize<'de> +{ + let mut de = Deserializer::from_reader(read); + let t = T::deserialize(&mut de)?; + Ok(t) +} + +impl<'r, 'de, R: Reader<'de>> Deserializer<'de, 'r, R> { + pub fn from_reader(read: &'r mut R) -> Self { + Deserializer { read, phantom: PhantomData } + } +} + +impl<'r, 'de, 'a, R: Reader<'de>> serde::de::Deserializer<'de> for &'a mut Deserializer<'de, 'r, R> +{ + type Error = Error; + + fn deserialize_any(self, _visitor: V) -> Result where V: Visitor<'de> + { + // Won't support this here -- use value::de::Deserializer for this + Err(Error::CannotDeserializeAny) + } + + fn deserialize_bool(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_bool(self.read.next_boolean()?) + } + + fn deserialize_i8(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_i8(self.read.next_i8()?) + } + + fn deserialize_i16(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_i16(self.read.next_i16()?) + } + + fn deserialize_i32(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_i32(self.read.next_i32()?) + } + + fn deserialize_i64(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_i64(self.read.next_i64()?) + } + + fn deserialize_u8(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_u8(self.read.next_u8()?) + } + + fn deserialize_u16(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_u16(self.read.next_u16()?) + } + + fn deserialize_u32(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_u32(self.read.next_u32()?) + } + + fn deserialize_u64(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_u64(self.read.next_u64()?) + } + + fn deserialize_f32(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_f32(self.read.next_float()?) + } + + fn deserialize_f64(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_f64(self.read.next_double()?) + } + + fn deserialize_char(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_char(self.read.next_char()?) + } + + fn deserialize_str(self, visitor: V) -> Result where V: Visitor<'de> + { + match self.read.next_str()? { + Cow::Borrowed(s) => visitor.visit_borrowed_str(s), + Cow::Owned(s) => visitor.visit_str(&s), + } + } + + fn deserialize_string(self, visitor: V) -> Result where V: Visitor<'de> + { + self.deserialize_str(visitor) + } + + fn deserialize_bytes(self, visitor: V) -> Result where V: Visitor<'de> + { + match self.read.next_bytestring()? { + Cow::Borrowed(bs) => visitor.visit_borrowed_bytes(bs), + Cow::Owned(bs) => visitor.visit_bytes(&bs), + } + } + + fn deserialize_byte_buf(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_byte_buf(self.read.next_bytestring()?.into_owned()) + } + + fn deserialize_option(self, visitor: V) -> Result where V: Visitor<'de> + { + let (is_some, mut compound_body) = self.read.open_option()?; + let result = if is_some { + compound_body.ensure_more_expected(self.read)?; + visitor.visit_some(&mut *self)? + } else { + visitor.visit_none::()? + }; + compound_body.ensure_complete(self.read)?; + Ok(result) + } + + fn deserialize_unit(self, visitor: V) -> Result where V: Visitor<'de> + { + let mut compound_body = self.read.open_simple_record("tuple", Some(0))?; + let result = visitor.visit_unit::()?; + compound_body.ensure_complete(self.read)?; + Ok(result) + } + + fn deserialize_unit_struct(self, name: &'static str, visitor: V) + -> Result where V: Visitor<'de> + { + let mut compound_body = self.read.open_simple_record(name, Some(0))?; + let result = visitor.visit_unit::()?; + compound_body.ensure_complete(self.read)?; + Ok(result) + } + + fn deserialize_newtype_struct(self, name: &'static str, visitor: V) + -> Result where V: Visitor<'de> + { + match super::value::magic::transmit_input_value( + name, || Ok(self.read.demand_next(false)?.into_owned()))? + { + Some(v) => visitor.visit_u64(v), + None => { + let mut compound_body = self.read.open_simple_record(name, Some(1))?; + compound_body.ensure_more_expected(self.read)?; + let result = visitor.visit_newtype_struct(&mut *self)?; + compound_body.ensure_complete(self.read)?; + Ok(result) + } + } + } + + fn deserialize_seq(self, visitor: V) -> Result where V: Visitor<'de> { + // Hack around serde's model: Deserialize *sets* as sequences, + // too, and reconstruct them as Rust Sets on the visitor side. + let compound_body = self.read.open_sequence_or_set()?; + visitor.visit_seq(Seq::new(self, compound_body)) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result where V: Visitor<'de> + { + let compound_body = self.read.open_simple_record("tuple", Some(len))?; + visitor.visit_seq(Seq::new(self, compound_body)) + } + + fn deserialize_tuple_struct(self, name: &'static str, len: usize, visitor: V) + -> Result where V: Visitor<'de> + { + let compound_body = self.read.open_simple_record(name, Some(len))?; + visitor.visit_seq(Seq::new(self, compound_body)) + } + + fn deserialize_map(self, visitor: V) -> Result where V: Visitor<'de> { + let compound_body = self.read.open_dictionary()?; + visitor.visit_map(Seq::new(self, compound_body)) + } + + fn deserialize_struct(self, + name: &'static str, + fields: &'static [&'static str], + visitor: V) + -> Result where V: Visitor<'de> + { + let compound_body = self.read.open_simple_record(name, Some(fields.len()))?; + visitor.visit_seq(Seq::new(self, compound_body)) + } + + fn deserialize_enum(self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V) + -> Result where V: Visitor<'de> + { + visitor.visit_enum(self) + } + + fn deserialize_identifier(self, visitor: V) -> Result where V: Visitor<'de> + { + match self.read.next_symbol()? { + Cow::Borrowed(s) => visitor.visit_borrowed_str(s), + Cow::Owned(s) => visitor.visit_str(&s), + } + } + + fn deserialize_ignored_any(self, visitor: V) -> Result where V: Visitor<'de> + { + visitor.visit_none() + } +} + +pub struct Seq<'de, 'r, 'a, R: Reader<'de>> { + de: &'a mut Deserializer<'de, 'r, R>, + compound_body: CompoundBody, +} + +impl<'de, 'r, 'a, R: Reader<'de>> Seq<'de, 'r, 'a, R> { + fn new(de: &'a mut Deserializer<'de, 'r, R>, compound_body: CompoundBody) -> Self { + Seq { de, compound_body } + } + + fn next_item(&mut self, seed: T) -> + Result> where T: DeserializeSeed<'de> + { + match self.compound_body.more_expected(self.de.read)? { + false => Ok(None), + true => Ok(Some(seed.deserialize(&mut *self.de)?)), + } + } +} + +impl<'de, 'r, 'a, R: Reader<'de>> SeqAccess<'de> for Seq<'de, 'r, 'a, R> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> + Result> where T: DeserializeSeed<'de> + { + self.next_item(seed) + } +} + +impl<'de, 'r, 'a, R: Reader<'de>> MapAccess<'de> for Seq<'de, 'r, 'a, R> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> + Result> where K: DeserializeSeed<'de> + { + self.next_item(seed) + } + + fn next_value_seed(&mut self, seed: V) -> + Result where V: DeserializeSeed<'de> + { + match self.next_item(seed)? { + Some(item) => Ok(item), + None => Err(Error::MissingItem), + } + } +} + +impl<'de, 'r, 'a, R: Reader<'de>> EnumAccess<'de> for &'a mut Deserializer<'de, 'r, R> { + type Error = Error; + type Variant = Seq<'de, 'r, 'a, R>; + + fn variant_seed(self, seed: V) + -> Result<(V::Value, Self::Variant)> where V: DeserializeSeed<'de> + { + let mut compound_body = self.read.open_record(None)?; + compound_body.ensure_more_expected(self.read)?; + let variant = seed.deserialize(&mut *self)?; + Ok((variant, Seq::new(self, compound_body))) + } +} + +impl<'de, 'r, 'a, R: Reader<'de>> VariantAccess<'de> for Seq<'de, 'r, 'a, R> { + type Error = Error; + + fn unit_variant(mut self) -> Result<()> { + self.compound_body.ensure_complete(self.de.read) + } + + fn newtype_variant_seed(mut self, seed: T) -> Result where T: DeserializeSeed<'de> { + match self.next_item(seed)? { + None => Err(Error::MissingItem), + Some(v) => { + self.compound_body.ensure_complete(self.de.read)?; + Ok(v) + } + } + } + + fn tuple_variant(self, _len: usize, visitor: V) -> Result where V: Visitor<'de> { + visitor.visit_seq(self) + } + + fn struct_variant(self, _fields: &'static [&'static str], visitor: V) + -> Result where V: Visitor<'de> + { + visitor.visit_seq(self) + } +} diff --git a/implementations/rust/src/error.rs b/implementations/rust/src/error.rs index 1509063..2a5adf3 100644 --- a/implementations/rust/src/error.rs +++ b/implementations/rust/src/error.rs @@ -1,30 +1,135 @@ +use num::bigint::BigInt; +use std::convert::From; +use crate::value::IOValue; + #[derive(Debug)] -pub struct Error { - inner: std::io::Error, +pub enum Error { + Io(std::io::Error), + Message(String), + InvalidUnicodeScalar(u32), + NumberOutOfRange(BigInt), + CannotDeserializeAny, + MissingCloseDelimiter, + MissingItem, + Expected(ExpectedKind, Received), } -impl std::convert::From for std::io::Error { - fn from(e: Error) -> Self { - e.inner +#[derive(Debug)] +pub enum Received { + ReceivedSomethingElse, + ReceivedRecordWithLabel(String), + ReceivedOtherValue(IOValue), +} + +#[derive(Debug, PartialEq)] +pub enum ExpectedKind { + Boolean, + Float, + Double, + + SignedInteger, + String, + ByteString, + Symbol, + + Record(Option), + SimpleRecord(&'static str, Option), + Sequence, + Set, + Dictionary, + + SequenceOrSet, // Because of hacking up serde's data model: see open_sequence_or_set etc. + + Option, + UnicodeScalar, +} + +impl From for Error { + fn from(e: std::io::Error) -> Self { + Error::Io(e) } } -impl std::convert::From for Error { - fn from(inner: std::io::Error) -> Self { - Error { inner } +impl From for std::io::Error { + fn from(e: Error) -> Self { + match e { + Error::Io(ioe) => ioe, + Error::Message(str) => std::io::Error::new(std::io::ErrorKind::Other, str), + _ => std::io::Error::new(std::io::ErrorKind::Other, e.to_string()), + } } } impl serde::ser::Error for Error { fn custom(msg: T) -> Self { - Error { inner: std::io::Error::new(std::io::ErrorKind::Other, msg.to_string()) } + Self::Message(msg.to_string()) } } -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::result::Result<(), std::fmt::Error> { - self.inner.fmt(f) +impl serde::de::Error for Error { + fn custom(msg: T) -> Self { + Self::Message(msg.to_string()) } } impl std::error::Error for Error {} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +//--------------------------------------------------------------------------- + +pub fn is_io_error(e: &Error) -> bool { + if let Error::Io(_) = e { true } else { false } +} + +pub fn eof() -> Error { + Error::Io(io_eof()) +} + +pub fn is_eof_error(e: &Error) -> bool { + if let Error::Io(ioe) = e { + is_eof_io_error(ioe) + } else { + false + } +} + +pub fn syntax_error(s: &str) -> Error { + Error::Io(io_syntax_error(s)) +} + +pub fn is_syntax_error(e: &Error) -> bool { + if let Error::Io(ioe) = e { + is_syntax_io_error(ioe) + } else { + false + } +} + +//--------------------------------------------------------------------------- + +pub fn io_eof() -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "EOF") +} + +pub fn is_eof_io_error(e: &std::io::Error) -> bool { + match e.kind() { + std::io::ErrorKind::UnexpectedEof => true, + _ => false, + } +} + +pub fn io_syntax_error(s: &str) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::InvalidData, s) +} + +pub fn is_syntax_io_error(e: &std::io::Error) -> bool { + match e.kind() { + std::io::ErrorKind::InvalidData => true, + _ => false, + } +} diff --git a/implementations/rust/src/lib.rs b/implementations/rust/src/lib.rs index acc5413..33ff756 100644 --- a/implementations/rust/src/lib.rs +++ b/implementations/rust/src/lib.rs @@ -1,13 +1,18 @@ +#[macro_use] +extern crate lazy_static; + pub mod de; pub mod ser; pub mod error; +pub mod set; pub mod symbol; pub mod value; #[cfg(test)] mod dom { use super::value::{ - Value, IOValue, NestedValue, PlainValue, Domain, Codec, + Value, IOValue, NestedValue, PlainValue, Domain, + encoder::encode_bytes, }; #[derive(Debug, Hash, Clone, Ord, PartialEq, Eq, PartialOrd)] @@ -31,7 +36,7 @@ mod dom { Value::Domain(Dom::One).wrap(), Value::from(2).wrap()]) .wrap(); - assert_eq!(Codec::without_placeholders().encode_bytes(&v.to_io_value()).unwrap(), + assert_eq!(encode_bytes(&v.to_io_value()).unwrap(), [147, 49, 100, 255, 255, 255, 255, 50]); } @@ -40,7 +45,7 @@ mod dom { Value::Domain(Dom::Two).wrap(), Value::from(2).wrap()]) .wrap(); - assert_eq!(Codec::without_placeholders().encode_bytes(&v.to_io_value()).unwrap(), + assert_eq!(encode_bytes(&v.to_io_value()).unwrap(), [147, 49, 120, 68, 111, 109, 58, 58, 84, 119, 111, 50]); } } @@ -227,14 +232,30 @@ mod value_tests { #[cfg(test)] mod decoder_tests { - use crate::value::{Decoder, BinaryReader}; - use crate::value::{Value, NestedValue}; + use crate::value::{Value, NestedValue, decoder}; + use crate::de::from_bytes; + use crate::error::{Error, ExpectedKind, is_eof_io_error}; + + fn expect_number_out_of_range(r: Result) { + match r { + Ok(v) => panic!("Expected NumberOutOfRange, but got a parse of {:?}", v), + Err(Error::NumberOutOfRange(_)) => (), + Err(e) => panic!("Expected NumberOutOfRange, but got an error of {:?}", e), + } + } + + fn expect_expected(k: ExpectedKind, r: Result) { + match r { + Ok(v) => panic!("Expected Expected({:?}), but got a parse of {:?}", k, v), + Err(Error::Expected(k1, _)) if k1 == k => (), + Err(e) => panic!("Expected Expected({:?}, but got an error of {:?}", k, e), + } + } #[test] fn skip_annotations_noskip() { let mut buf = &b"\x0521"[..]; - let r = BinaryReader::new(&mut buf); - let mut d = Decoder::new(r, None); - let v = d.next_or_err().unwrap(); + let mut d = decoder::from_bytes(&mut buf); + let v = d.demand_next().unwrap(); assert_eq!(v.annotations().len(), 1); assert_eq!(v.annotations()[0], Value::from(2).wrap()); assert_eq!(v.value(), &Value::from(1)); @@ -242,54 +263,248 @@ mod decoder_tests { #[test] fn skip_annotations_skip() { let mut buf = &b"\x0521"[..]; - let r = BinaryReader::new(&mut buf); - let mut d = Decoder::new(r, None); + let mut d = decoder::from_bytes(&mut buf); d.set_read_annotations(false); - let v = d.next_or_err().unwrap(); + let v = d.demand_next().unwrap(); assert_eq!(v.annotations().len(), 0); assert_eq!(v.value(), &Value::from(1)); } - #[test] fn two_values_at_once() { + #[test] fn multiple_values_buf_advanced() { let mut buf = &b"\x81tPing\x81tPong"[..]; assert_eq!(buf.len(), 12); - let r = BinaryReader::new(&mut buf); - let mut d = Decoder::new(r, None); - assert_eq!(d.next_or_err().unwrap().value(), &Value::simple_record("Ping", vec![])); - assert_eq!(d.next_or_err().unwrap().value(), &Value::simple_record("Pong", vec![])); - assert_eq!(buf.len(), 0); + let mut d = decoder::from_bytes(&mut buf); + assert_eq!(d.read.source.index, 0); + assert_eq!(d.demand_next().unwrap().value(), &Value::simple_record("Ping", vec![])); + assert_eq!(d.read.source.index, 6); + assert_eq!(d.demand_next().unwrap().value(), &Value::simple_record("Pong", vec![])); + assert_eq!(d.read.source.index, 12); + assert!(if let None = d.next() { true } else { false }); + assert!(if let Err(e) = d.demand_next() { is_eof_io_error(&e) } else { false }); } - #[test] fn buf_advanced() { - let mut buf = &b"\x81tPing\x81tPong"[..]; - assert_eq!(buf.len(), 12); - let mut r = BinaryReader::new(&mut buf); - let mut d = Decoder::new(&mut r, None); - assert_eq!(d.next_or_err().unwrap().value(), &Value::simple_record("Ping", vec![])); - assert_eq!(buf.len(), 6); - let mut r = BinaryReader::new(&mut buf); - let mut d = Decoder::new(&mut r, None); - assert_eq!(d.next_or_err().unwrap().value(), &Value::simple_record("Pong", vec![])); - assert_eq!(buf.len(), 0); + #[test] fn direct_i8_format_a_positive() { assert_eq!(from_bytes::(b"1").unwrap(), 1) } + #[test] fn direct_i8_format_a_zero() { assert_eq!(from_bytes::(b"0").unwrap(), 0) } + #[test] fn direct_i8_format_a_negative() { assert_eq!(from_bytes::(b"?").unwrap(), -1) } + #[test] fn direct_i8_format_b() { assert_eq!(from_bytes::(b"A\xfe").unwrap(), -2) } + #[test] fn direct_i8_format_b_too_long() { assert_eq!(from_bytes::(b"C\xff\xff\xfe").unwrap(), -2) } + #[test] fn direct_i8_format_b_much_too_long() { assert_eq!(from_bytes::(b"J\xff\xff\xff\xff\xff\xff\xff\xff\xff\xfe").unwrap(), -2) } + #[test] fn direct_i8_format_c() { assert_eq!(from_bytes::(b"$a\xfe\x04").unwrap(), -2) } + #[test] fn direct_i8_format_c_too_long() { assert_eq!(from_bytes::(b"$a\xffa\xfe\x04").unwrap(), -2) } + #[test] fn direct_i8_format_c_much_too_long() { assert_eq!(from_bytes::(b"$a\xffa\xffa\xffa\xffa\xffa\xffa\xffa\xffa\xffa\xffa\xffa\xfe\x04").unwrap(), -2) } + + #[test] fn direct_u8_format_a_positive() { assert_eq!(from_bytes::(b"1").unwrap(), 1) } + #[test] fn direct_u8_format_a_zero() { assert_eq!(from_bytes::(b"0").unwrap(), 0) } + #[test] fn direct_u8_format_b() { assert_eq!(from_bytes::(b"A1").unwrap(), 49) } + #[test] fn direct_u8_format_b_too_long() { assert_eq!(from_bytes::(b"D\0\0\01").unwrap(), 49) } + #[test] fn direct_u8_format_b_much_too_long() { assert_eq!(from_bytes::(b"J\0\0\0\0\0\0\0\0\01").unwrap(), 49) } + #[test] fn direct_u8_format_c() { assert_eq!(from_bytes::(b"$a1\x04").unwrap(), 49) } + #[test] fn direct_u8_format_c_too_long() { assert_eq!(from_bytes::(b"$a\0a1\x04").unwrap(), 49) } + #[test] fn direct_u8_format_c_much_too_long() { assert_eq!(from_bytes::(b"$a\0a\0a\0a\0a\0a\0a\0a\0a\0a\0a\0a1\x04").unwrap(), 49) } + + #[test] fn direct_i16_format_a() { assert_eq!(from_bytes::(b">").unwrap(), -2) } + #[test] fn direct_i16_format_b() { assert_eq!(from_bytes::(b"B\xfe\xff").unwrap(), -257) } + + #[test] fn direct_u8_wrong_format() { + expect_expected(ExpectedKind::SignedInteger, from_bytes::(b"Ubogus")) + } + + #[test] fn direct_u8_format_b_too_large() { + expect_number_out_of_range(from_bytes::(b"D\0\011")) + } + + #[test] fn direct_u8_format_c_too_large() { + expect_number_out_of_range(from_bytes::(b"$a1a1\x04")) + } + + #[test] fn direct_i8_format_b_too_large() { + expect_number_out_of_range(from_bytes::(b"B\xfe\xff")) + } + + #[test] fn direct_i8_format_c_too_large() { + expect_number_out_of_range(from_bytes::(b"$a\xfea\xff\x04")) + } + + #[test] fn direct_i16_format_b_too_large() { + expect_number_out_of_range(from_bytes::(b"C\xfe\xff\xff")); + } + + #[test] fn direct_i32_format_b_ok() { + assert_eq!(from_bytes::(b"C\xfe\xff\xff").unwrap(), -65537); + } + + #[test] fn direct_i32_format_b_ok_2() { + assert_eq!(from_bytes::(b"D\xfe\xff\xff\xff").unwrap(), -16777217); + } + + #[test] fn direct_i64_format_b() { + assert_eq!(from_bytes::(b"A\xff").unwrap(), -1); + assert_eq!(from_bytes::(b"C\xff\xff\xff").unwrap(), -1); + assert_eq!(from_bytes::(b"J\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff").unwrap(), -1); + assert_eq!(from_bytes::(b"A\xfe").unwrap(), -2); + assert_eq!(from_bytes::(b"C\xff\xfe\xff").unwrap(), -257); + assert_eq!(from_bytes::(b"C\xfe\xff\xff").unwrap(), -65537); + assert_eq!(from_bytes::(b"J\xff\xff\xff\xff\xff\xff\xfe\xff\xff\xff").unwrap(), -16777217); + assert_eq!(from_bytes::(b"J\xff\xff\xfe\xff\xff\xff\xff\xff\xff\xff").unwrap(), -72057594037927937); + expect_number_out_of_range(from_bytes::(b"J\xff\xff\x0e\xff\xff\xff\xff\xff\xff\xff")); + expect_number_out_of_range(from_bytes::(b"I\xff\x0e\xff\xff\xff\xff\xff\xff\xff")); + expect_number_out_of_range(from_bytes::(b"I\x80\x0e\xff\xff\xff\xff\xff\xff\xff")); + expect_number_out_of_range(from_bytes::(b"J\xff\x00\x0e\xff\xff\xff\xff\xff\xff\xff")); + assert_eq!(from_bytes::(b"H\xfe\xff\xff\xff\xff\xff\xff\xff").unwrap(), -72057594037927937); + assert_eq!(from_bytes::(b"H\x0e\xff\xff\xff\xff\xff\xff\xff").unwrap(), 1080863910568919039); + assert_eq!(from_bytes::(b"H\x80\0\0\0\0\0\0\0").unwrap(), -9223372036854775808); + assert_eq!(from_bytes::(b"H\0\0\0\0\0\0\0\0").unwrap(), 0); + assert_eq!(from_bytes::(b"@").unwrap(), 0); + assert_eq!(from_bytes::(b"H\x7f\xff\xff\xff\xff\xff\xff\xff").unwrap(), 9223372036854775807); + } + + #[test] fn direct_u64_format_b() { + expect_number_out_of_range(from_bytes::(b"A\xff")); + assert_eq!(from_bytes::(b"B\0\xff").unwrap(), 255); + expect_number_out_of_range(from_bytes::(b"C\xff\xff\xff")); + assert_eq!(from_bytes::(b"D\0\xff\xff\xff").unwrap(), 0xffffff); + expect_number_out_of_range(from_bytes::(b"J\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")); + assert_eq!(from_bytes::(b"A\x02").unwrap(), 2); + assert_eq!(from_bytes::(b"C\x00\x01\x00").unwrap(), 256); + assert_eq!(from_bytes::(b"C\x01\x00\x00").unwrap(), 65536); + assert_eq!(from_bytes::(b"J\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00").unwrap(), 16777216); + assert_eq!(from_bytes::(b"J\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00").unwrap(), 72057594037927936); + assert_eq!(from_bytes::(b"J\x00\x00\xf2\x00\x00\x00\x00\x00\x00\x00").unwrap(), 0xf200000000000000); + assert_eq!(from_bytes::(b"J\x00\x00\x72\x00\x00\x00\x00\x00\x00\x00").unwrap(), 0x7200000000000000); + expect_number_out_of_range(from_bytes::(b"J\x00\xf2\x00\x00\x00\x00\x00\x00\x00\x00")); + assert_eq!(from_bytes::(b"I\x00\xf2\x00\x00\x00\x00\x00\x00\x00").unwrap(), 0xf200000000000000); + expect_number_out_of_range(from_bytes::(b"I\x7f\xf2\x00\x00\x00\x00\x00\x00\x00")); + expect_number_out_of_range(from_bytes::(b"J\x00\xff\xf2\x00\x00\x00\x00\x00\x00\x00")); + assert_eq!(from_bytes::(b"H\x01\x00\x00\x00\x00\x00\x00\x00").unwrap(), 72057594037927936); + assert_eq!(from_bytes::(b"H\x0e\xff\xff\xff\xff\xff\xff\xff").unwrap(), 1080863910568919039); + expect_number_out_of_range(from_bytes::(b"H\x80\0\0\0\0\0\0\0")); + assert_eq!(from_bytes::(b"I\0\x80\0\0\0\0\0\0\0").unwrap(), 9223372036854775808); + assert_eq!(from_bytes::(b"H\0\0\0\0\0\0\0\0").unwrap(), 0); + assert_eq!(from_bytes::(b"@").unwrap(), 0); + assert_eq!(from_bytes::(b"H\x7f\xff\xff\xff\xff\xff\xff\xff").unwrap(), 9223372036854775807); + } +} + +#[cfg(test)] +mod serde_tests { + use crate::symbol::Symbol; + use crate::de::from_bytes as deserialize_from_bytes; + use crate::value::de::from_value as deserialize_from_value; + use crate::value::encoder::encode_bytes; + use crate::value::to_value; + use crate::value::{Value, IOValue, Map, Set}; + + #[test] fn simple_to_value() { + use serde::Serialize; + #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] + struct Colour{ red: u8, green: u8, blue: u8 } + #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)] + struct SimpleValue<'a>(String, + #[serde(with = "crate::symbol")] String, + Symbol, + #[serde(with = "crate::symbol")] String, + Symbol, + &'a str, + #[serde(with = "serde_bytes")] &'a [u8], + #[serde(with = "serde_bytes")] Vec, + Vec, + #[serde(with = "crate::set")] Set, + i16, + IOValue, + Map, + f32, + f64); + let mut str_set = Set::new(); + str_set.insert("one".to_owned()); + str_set.insert("two".to_owned()); + str_set.insert("three".to_owned()); + let mut colours = Map::new(); + colours.insert("red".to_owned(), Colour { red: 255, green: 0, blue: 0 }); + colours.insert("green".to_owned(), Colour { red: 0, green: 255, blue: 0 }); + colours.insert("blue".to_owned(), Colour { red: 0, green: 0, blue: 255 }); + let v = SimpleValue("hello".to_string(), + "sym1".to_string(), + Symbol("sym2".to_string()), + "sym3".to_string(), + Symbol("sym4".to_string()), + "world", + &b"slice"[..], + b"vec".to_vec(), + vec![false, true, false, true], + str_set, + 12345, + Value::from("hi").wrap(), + colours, + 12.345, + 12.3456789); + println!("== v: {:#?}", v); + let w: IOValue = to_value(&v); + println!("== w: {:#?}", w); + let x = deserialize_from_value(&w).unwrap(); + println!("== x: {:#?}", &x); + assert_eq!(v, x); + + let expected_bytes = vec![ + 0x8f, 0x10, // Struct, 15 members + 1 label + 0x7b, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x56, 0x61, 0x6c, 0x75, 0x65, // SimpleValue + 0x55, 0x68, 0x65, 0x6c, 0x6c, 0x6f, // "hello" + 0x74, 0x73, 0x79, 0x6d, 0x31, // sym1 + 0x74, 0x73, 0x79, 0x6d, 0x32, // sym2 + 0x74, 0x73, 0x79, 0x6d, 0x33, // sym3 + 0x74, 0x73, 0x79, 0x6d, 0x34, // sym4 + 0x55, 0x77, 0x6f, 0x72, 0x6c, 0x64, // "world" + 0x65, 0x73, 0x6c, 0x69, 0x63, 0x65, // #"slice" + 0x63, 0x76, 0x65, 0x63, // #"vec" + 0x94, // Sequence, 4 items + 0x0, // false + 0x1, // true + 0x0, // false + 0x1, // true + 0xa3, // Set, 3 items + 0x53, 0x6f, 0x6e, 0x65, + 0x55, 0x74, 0x68, 0x72, 0x65, 0x65, + 0x53, 0x74, 0x77, 0x6f, + 0x42, 0x30, 0x39, // 12345 + 0x52, 0x68, 0x69, // "hi" + 0xb6, // Dictionary, 6 items = 3 key/value pairs + 0x54, 0x62, 0x6c, 0x75, 0x65, // "blue" + 0x84, 0x76, 0x43, 0x6f, 0x6c, 0x6f, 0x75, 0x72, 0x30, 0x30, 0x42, 0x00, 0xff, + 0x55, 0x67, 0x72, 0x65, 0x65, 0x6e, // "green" + 0x84, 0x76, 0x43, 0x6f, 0x6c, 0x6f, 0x75, 0x72, 0x30, 0x42, 0x00, 0xff, 0x30, + 0x53, 0x72, 0x65, 0x64, // "red" + 0x84, 0x76, 0x43, 0x6f, 0x6c, 0x6f, 0x75, 0x72, 0x42, 0x00, 0xff, 0x30, 0x30, + + 0x2, 0x41, 0x45, 0x85, 0x1f, // 12.345, + 0x3, 0x40, 0x28, 0xb0, 0xfc, 0xd3, 0x24, 0xd5, 0xa2, // 12.3456789 + ]; + + let v_bytes_1 = encode_bytes(&w).unwrap(); + println!("== w bytes = {:?}", v_bytes_1); + assert_eq!(expected_bytes, v_bytes_1); + + let mut v_bytes_2 = Vec::new(); + v.serialize(&mut crate::ser::Serializer::new(&mut v_bytes_2)).unwrap(); + println!("== v bytes = {:?}", v_bytes_2); + assert_eq!(v_bytes_1, v_bytes_2); + + let y = deserialize_from_bytes(&v_bytes_1).unwrap(); + println!("== y: {:#?}", &y); + assert_eq!(v, y); } } #[cfg(test)] mod samples_tests { use crate::symbol::Symbol; - use crate::value::{Codec, Decoder, BinaryReader}; - use crate::value::reader::is_syntax_error; - use crate::value::{Value, IOValue, Map}; - use crate::value::DecodePlaceholderMap; - use crate::value::to_value; - use crate::value::from_value; - - #[derive(Debug, serde::Serialize, serde::Deserialize)] - struct ExpectedPlaceholderMapping(DecodePlaceholderMap); + use crate::error::{is_eof_io_error, is_syntax_io_error}; + use crate::value::de::from_value as deserialize_from_value; + use crate::value::decoder; + use crate::value::encoder::encode_bytes; + use crate::value::{IOValue, Map}; + use std::iter::Iterator; #[derive(Debug, serde::Serialize, serde::Deserialize)] struct TestCases { - decode_placeholders: ExpectedPlaceholderMapping, tests: Map } @@ -301,49 +516,55 @@ mod samples_tests { DecodeTest(#[serde(with = "serde_bytes")] Vec, IOValue), ParseError(String), ParseShort(String), + ParseEOF(String), DecodeError(#[serde(with = "serde_bytes")] Vec), DecodeShort(#[serde(with = "serde_bytes")] Vec), + DecodeEOF(#[serde(with = "serde_bytes")] Vec), + } + + fn decode_all<'de>(bytes: &'de [u8]) -> Result, std::io::Error> { + let d = decoder::from_bytes(bytes); + d.collect() } #[test] fn run() -> std::io::Result<()> { let mut fh = std::fs::File::open("../../tests/samples.bin").unwrap(); - let r = BinaryReader::new(&mut fh); - let mut d = Decoder::new(r, None); - let tests: TestCases = from_value(&d.next_or_err().unwrap()).unwrap(); + let mut d = decoder::from_read(&mut fh); + let tests: TestCases = deserialize_from_value(&d.next().unwrap().unwrap()).unwrap(); // println!("{:#?}", tests); - let codec = Codec::new(tests.decode_placeholders.0); for (Symbol(ref name), ref case) in tests.tests { println!("{:?} ==> {:?}", name, case); match case { TestCase::Test(ref bin, ref val) => { - assert_eq!(&codec.decode_all(&mut &codec.encode_bytes(val)?[..])?, &[val.clone()]); - assert_eq!(&codec.decode_all(&mut &bin[..])?, &[val.clone()]); - assert_eq!(&codec.encode_bytes(val)?, bin); + assert_eq!(&decode_all(&encode_bytes(val)?[..])?, &[val.clone()]); + assert_eq!(&decode_all(&bin[..])?, &[val.clone()]); + assert_eq!(&encode_bytes(val)?, bin); } TestCase::NondeterministicTest(ref bin, ref val) => { // The test cases in samples.txt are carefully // written so that while strictly // "nondeterministic", the order of keys in // dictionaries follows Preserves order. - assert_eq!(&codec.decode_all(&mut &codec.encode_bytes(val)?[..])?, &[val.clone()]); - assert_eq!(&codec.decode_all(&mut &bin[..])?, &[val.clone()]); - assert_eq!(&codec.encode_bytes(val)?, bin); + assert_eq!(&decode_all(&encode_bytes(val)?[..])?, &[val.clone()]); + assert_eq!(&decode_all(&bin[..])?, &[val.clone()]); + assert_eq!(&encode_bytes(val)?, bin); } TestCase::StreamingTest(ref bin, ref val) => { - assert_eq!(&codec.decode_all(&mut &codec.encode_bytes(val)?[..])?, &[val.clone()]); - assert_eq!(&codec.decode_all(&mut &bin[..])?, &[val.clone()]); + assert_eq!(&decode_all(&encode_bytes(val)?[..])?, &[val.clone()]); + assert_eq!(&decode_all(&bin[..])?, &[val.clone()]); } TestCase::DecodeTest(ref bin, ref val) => { - assert_eq!(&codec.decode_all(&mut &codec.encode_bytes(val)?[..])?, &[val.clone()]); - assert_eq!(&codec.decode_all(&mut &bin[..])?, &[val.clone()]); + assert_eq!(&decode_all(&encode_bytes(val)?[..])?, &[val.clone()]); + assert_eq!(&decode_all(&bin[..])?, &[val.clone()]); } TestCase::ParseError(_) => (), TestCase::ParseShort(_) => (), + TestCase::ParseEOF(_) => (), TestCase::DecodeError(ref bin) => { - match codec.decode_all(&mut &bin[..]) { + match decode_all(&bin[..]) { Ok(_) => panic!("Unexpected success"), - Err(e) => if is_syntax_error(&e) { + Err(e) => if is_syntax_io_error(&e) { () } else { panic!("Unexpected error {:?}", e) @@ -351,51 +572,17 @@ mod samples_tests { } } TestCase::DecodeShort(ref bin) => { - assert_eq!(codec.decode_all(&mut &bin[..])?.len(), 0); + assert!(if let Err(e) = decoder::from_bytes(bin).next().unwrap() { + is_eof_io_error(&e) + } else { + false + }) + } + TestCase::DecodeEOF(ref bin) => { + assert!(if let None = decoder::from_bytes(bin).next() { true } else { false }) } } } Ok(()) } - - #[test] fn simple_to_value() { - use serde::Serialize; - #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] - struct SimpleValue<'a>(String, - #[serde(with = "crate::symbol")] String, - Symbol, - #[serde(with = "crate::symbol")] String, - Symbol, - &'a str, - #[serde(with = "serde_bytes")] &'a [u8], - #[serde(with = "serde_bytes")] Vec, - i16, - IOValue); - let v = SimpleValue("hello".to_string(), - "sym1".to_string(), - Symbol("sym2".to_string()), - "sym3".to_string(), - Symbol("sym4".to_string()), - "world", - &b"slice"[..], - b"vec".to_vec(), - 12345, - Value::from("hi").wrap()); - println!("== v: {:#?}", v); - let w: IOValue = to_value(&v); - println!("== w: {:#?}", w); - let x = from_value(&w).unwrap(); - println!("== x: {:#?}", &x); - assert_eq!(v, x); - let mut placeholders = Map::new(); - placeholders.insert(5, Value::symbol("sym1")); - placeholders.insert(500, Value::symbol("sym2")); - placeholders.insert(0, Value::symbol("SimpleValue")); - let v_bytes_1 = Codec::new(placeholders.clone()).encode_bytes(&w).unwrap(); - let mut v_bytes_2 = Vec::new(); - v.serialize(&mut crate::ser::Serializer::new(&mut v_bytes_2, Some(&crate::value::invert_map(&placeholders)))).unwrap(); - println!("== w bytes = {:?}", v_bytes_1); - println!("== v bytes = {:?}", v_bytes_2); - assert_eq!(v_bytes_1, v_bytes_2); - } } diff --git a/implementations/rust/src/ser.rs b/implementations/rust/src/ser.rs index 3c44771..9a920de 100644 --- a/implementations/rust/src/ser.rs +++ b/implementations/rust/src/ser.rs @@ -1,6 +1,6 @@ use serde::Serialize; use super::value::writer::Writer; -use super::value::{Value, EncodePlaceholderMap, Encoder}; +use super::value::Encoder; pub use super::error::Error; type Result = std::result::Result; @@ -8,19 +8,11 @@ type Result = std::result::Result; #[derive(Debug)] pub struct Serializer<'a, W: Writer> { pub write: &'a mut W, - placeholders: Option<&'a EncodePlaceholderMap>, } impl<'a, W: Writer> Serializer<'a, W> { - pub fn new(write: &'a mut W, placeholders: Option<&'a EncodePlaceholderMap>) -> Self { - Serializer { write, placeholders } - } - - fn write_symbol(&mut self, s: &str) -> Result<()> { - match self.placeholders.as_ref().and_then(|m| m.get(&Value::symbol(s))) { - Some(&n) => Ok(self.write.write_placeholder_ref(n)?), - None => Ok(self.write.write_symbol(s)?), - } + pub fn new(write: &'a mut W) -> Self { + Serializer { write } } } @@ -87,7 +79,7 @@ impl<'a, 'b, W: Writer> serde::Serializer for &'a mut Serializer<'b, W> { fn serialize_char(self, v: char) -> Result<()> { self.write.open_record(1)?; - self.write_symbol("UnicodeScalar")?; + self.write.write_symbol("UnicodeScalar")?; self.write.write_u32(v as u32)?; Ok(self.write.close_record()?) } @@ -102,26 +94,26 @@ impl<'a, 'b, W: Writer> serde::Serializer for &'a mut Serializer<'b, W> { fn serialize_none(self) -> Result<()> { self.write.open_record(0)?; - self.write_symbol("None")?; + self.write.write_symbol("None")?; Ok(self.write.close_record()?) } fn serialize_some(self, v: &T) -> Result<()> where T: Serialize { self.write.open_record(1)?; - self.write_symbol("Some")?; + self.write.write_symbol("Some")?; v.serialize(&mut *self)?; Ok(self.write.close_record()?) } fn serialize_unit(self) -> Result<()> { self.write.open_record(0)?; - self.write_symbol("tuple")?; + self.write.write_symbol("tuple")?; Ok(self.write.close_record()?) } fn serialize_unit_struct(self, name: &'static str) -> Result<()> { self.write.open_record(0)?; - self.write_symbol(name)?; + self.write.write_symbol(name)?; Ok(self.write.close_record()?) } @@ -132,7 +124,7 @@ impl<'a, 'b, W: Writer> serde::Serializer for &'a mut Serializer<'b, W> { Result<()> { self.write.open_record(0)?; - self.write_symbol(variant_name)?; + self.write.write_symbol(variant_name)?; Ok(self.write.close_record()?) } @@ -141,13 +133,13 @@ impl<'a, 'b, W: Writer> serde::Serializer for &'a mut Serializer<'b, W> { { match super::value::magic::receive_output_value(name, value) { Some(v) => { - Encoder::new(self.write, self.placeholders).write(&v)?; + Encoder::new(self.write).write(&v)?; Ok(()) } None => { // TODO: This is apparently discouraged, and we should apparently just serialize `value`? self.write.open_record(1)?; - self.write_symbol(name)?; + self.write.write_symbol(name)?; value.serialize(&mut *self)?; Ok(self.write.close_record()?) } @@ -162,7 +154,7 @@ impl<'a, 'b, W: Writer> serde::Serializer for &'a mut Serializer<'b, W> { Result<()> where T: Serialize { self.write.open_record(1)?; - self.write_symbol(variant_name)?; + self.write.write_symbol(variant_name)?; value.serialize(&mut *self)?; Ok(self.write.close_record()?) } @@ -177,7 +169,7 @@ impl<'a, 'b, W: Writer> serde::Serializer for &'a mut Serializer<'b, W> { fn serialize_tuple(self, count: usize) -> Result { self.write.open_record(count)?; - self.write_symbol("tuple")?; + self.write.write_symbol("tuple")?; Ok(SerializeCompound { ser: self, count: Some(count) }) } @@ -185,7 +177,7 @@ impl<'a, 'b, W: Writer> serde::Serializer for &'a mut Serializer<'b, W> { Result { self.write.open_record(count)?; - self.write_symbol(name)?; + self.write.write_symbol(name)?; Ok(SerializeCompound { ser: self, count: Some(count) }) } @@ -197,7 +189,7 @@ impl<'a, 'b, W: Writer> serde::Serializer for &'a mut Serializer<'b, W> { Result { self.write.open_record(count)?; - self.write_symbol(variant_name)?; + self.write.write_symbol(variant_name)?; Ok(SerializeCompound { ser: self, count: Some(count) }) } @@ -211,7 +203,7 @@ impl<'a, 'b, W: Writer> serde::Serializer for &'a mut Serializer<'b, W> { fn serialize_struct(self, name: &'static str, count: usize) -> Result { self.write.open_record(count)?; - self.write_symbol(name)?; + self.write.write_symbol(name)?; Ok(SerializeCompound { ser: self, count: Some(count) }) } @@ -223,7 +215,7 @@ impl<'a, 'b, W: Writer> serde::Serializer for &'a mut Serializer<'b, W> { Result { self.write.open_record(count)?; - self.write_symbol(variant_name)?; + self.write.write_symbol(variant_name)?; Ok(SerializeCompound { ser: self, count: Some(count) }) } } @@ -349,6 +341,6 @@ impl<'a, 'b, W: Writer> serde::ser::SerializeSeq for SerializeCompound<'a, 'b, W } pub fn to_writer(write: &mut W, value: &T) -> Result<()> { - let mut ser: Serializer<'_, W> = Serializer::new(write, None); + let mut ser: Serializer<'_, W> = Serializer::new(write); value.serialize(&mut ser) } diff --git a/implementations/rust/src/set.rs b/implementations/rust/src/set.rs new file mode 100644 index 0000000..0286a3b --- /dev/null +++ b/implementations/rust/src/set.rs @@ -0,0 +1,24 @@ +use crate::value::{self, to_value, IOValue, UnwrappedIOValue}; +use std::iter::IntoIterator; +use serde::{Serialize, Serializer, Deserialize, Deserializer}; + +pub fn serialize(s: T, serializer: S) -> Result +where + S: Serializer, + T: IntoIterator, + Item: Serialize, +{ + let s = s.into_iter() + .map(|item| to_value(item)) + .collect::>(); + UnwrappedIOValue::from(s).wrap().serialize(serializer) +} + +pub fn deserialize<'de, D, T>(deserializer: D) -> Result +where + D: Deserializer<'de>, + T: Deserialize<'de>, +{ + // Relies on the way we hack around serde's data model in de.rs and value/de.rs. + T::deserialize(deserializer) +} diff --git a/implementations/rust/src/symbol.rs b/implementations/rust/src/symbol.rs index ea7d14e..871454e 100644 --- a/implementations/rust/src/symbol.rs +++ b/implementations/rust/src/symbol.rs @@ -1,17 +1,17 @@ -use crate::value::{Value, PlainValue, NestedValue, NullDomain}; +use crate::value::{IOValue, UnwrappedIOValue, NestedValue}; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] pub struct Symbol(pub String); impl serde::Serialize for Symbol { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { - Value::, NullDomain>::symbol(&self.0).wrap().serialize(serializer) + UnwrappedIOValue::symbol(&self.0).wrap().serialize(serializer) } } impl<'de> serde::Deserialize<'de> for Symbol { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de> { - let v = PlainValue::::deserialize(deserializer)?; + let v = IOValue::deserialize(deserializer)?; let s = v.value().as_symbol().ok_or_else(|| serde::de::Error::custom("Expected symbol"))?; Ok(Symbol(s.clone())) } diff --git a/implementations/rust/src/value/codec.rs b/implementations/rust/src/value/codec.rs deleted file mode 100644 index 0056a55..0000000 --- a/implementations/rust/src/value/codec.rs +++ /dev/null @@ -1,49 +0,0 @@ -use std::io::{Read, Write, Error}; -use super::{ - decoder::{self, Decoder, DecodePlaceholderMap}, - encoder::{Encoder, EncodePlaceholderMap}, - invert_map, - reader::{BinaryReader, is_eof_error}, - value::IOValue, -}; - -pub struct Codec { - pub decode_placeholders: Option, - pub encode_placeholders: Option, -} - -impl Codec { - pub fn new(decode_placeholders: DecodePlaceholderMap) -> Self { - let encode_placeholders = Some(invert_map(&decode_placeholders)); - Codec { decode_placeholders: Some(decode_placeholders), encode_placeholders } - } - - pub fn without_placeholders() -> Self { - Codec { decode_placeholders: None, encode_placeholders: None } - } - - pub fn decoder<'a, 'r, R: Read>(&'a self, read: &'r mut R) -> Decoder<'a, BinaryReader<'r, R>> { - Decoder::new(BinaryReader::new(read), self.decode_placeholders.as_ref()) - } - - pub fn encoder<'a, 'w, W: Write>(&'a self, write: &'w mut W) -> Encoder<'w, 'a, W> { - Encoder::new(write, self.encode_placeholders.as_ref()) - } - - pub fn decode_all<'r, R: Read>(&self, read: &'r mut R) -> decoder::Result> { - let mut r = BinaryReader::new(read); - let vs: Vec = Decoder::new(&mut r, self.decode_placeholders.as_ref()) - .collect::>>()?; - match r.peek() { - Err(e) if is_eof_error(&e) => Ok(vs), - Err(e) => Err(e), - Ok(_) => Err(Error::new(std::io::ErrorKind::Other, "trailing bytes")), - } - } - - pub fn encode_bytes(&self, v: &IOValue) -> std::io::Result> { - let mut buf: Vec = Vec::new(); - self.encoder(&mut buf).write(v)?; - Ok(buf) - } -} diff --git a/implementations/rust/src/value/constants.rs b/implementations/rust/src/value/constants.rs index 5cfa039..6cca5f1 100644 --- a/implementations/rust/src/value/constants.rs +++ b/implementations/rust/src/value/constants.rs @@ -19,6 +19,12 @@ impl From for std::io::Error { } } +impl From for crate::error::Error { + fn from(v: InvalidOp) -> Self { + crate::error::Error::Io(v.into()) + } +} + impl TryFrom for Op { type Error = InvalidOp; fn try_from(v: u8) -> Result { diff --git a/implementations/rust/src/value/de.rs b/implementations/rust/src/value/de.rs index e38329d..0049559 100644 --- a/implementations/rust/src/value/de.rs +++ b/implementations/rust/src/value/de.rs @@ -1,60 +1,11 @@ -use crate::value::{Value, NestedValue, IOValue, UnwrappedIOValue, Map}; use crate::value::value::{Float, Double}; +use crate::value::{Value, NestedValue, IOValue, UnwrappedIOValue, Map}; +use crate::error::{Error, ExpectedKind, Received}; use num::traits::cast::ToPrimitive; use serde::Deserialize; use serde::de::{Visitor, SeqAccess, MapAccess, EnumAccess, VariantAccess, DeserializeSeed}; -use std::convert::TryFrom; use std::iter::Iterator; -pub mod error { - use num::bigint::BigInt; - use crate::value::IOValue; - - #[derive(Debug)] - pub enum Error { - Message(String), - InvalidUnicodeScalar(u32), - NumberTooLarge(BigInt), - CannotDeserializeAny, - Expected(ExpectedKind, IOValue), - } - - #[derive(Debug)] - pub enum ExpectedKind { - Boolean, - Float, - Double, - - SignedInteger, - String, - ByteString, - Symbol, - - Record(Option), - SimpleRecord(&'static str, Option), - Option, - Sequence, - Dictionary, - } - - impl serde::de::Error for Error { - fn custom(msg: T) -> Self { - Self::Message(msg.to_string()) - } - } - - impl std::error::Error for Error {} - - impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } - } -} - -pub use error::Error; -use error::ExpectedKind; - pub type Result = std::result::Result; pub struct Deserializer<'de> { @@ -76,7 +27,8 @@ impl<'de> Deserializer<'de> { fn check<'a, T, F>(&'a mut self, f: F, k: ExpectedKind) -> Result where F: FnOnce(&'de UnwrappedIOValue) -> Option { - f(self.input.value()).ok_or_else(|| Error::Expected(k, self.input.clone())) + f(self.input.value()).ok_or_else( + || Error::Expected(k, Received::ReceivedOtherValue(self.input.clone()))) } } @@ -94,7 +46,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> Value::SignedInteger(ref i) => match i.to_i64() { None => match i.to_u64() { - None => Err(Error::NumberTooLarge(i.clone())), + None => Err(Error::NumberOutOfRange(i.clone())), Some(n) => visitor.visit_u64(n), } Some(n) => visitor.visit_i64(n), @@ -109,11 +61,11 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> } else if v.is_simple_record("None", Some(0)) || v.is_simple_record("Some", Some(1)) { self.deserialize_option(visitor) } else if v.is_simple_record("tuple", None) { - visitor.visit_seq(VecSeq::new(self, v.as_simple_record("tuple", None).unwrap())) + visitor.visit_seq(VecSeq::new(self, v.as_simple_record("tuple", None).unwrap().iter())) } else { Err(Error::CannotDeserializeAny) } - Value::Sequence(ref v) => visitor.visit_seq(VecSeq::new(self, v)), + Value::Sequence(ref v) => visitor.visit_seq(VecSeq::new(self, v.iter())), Value::Dictionary(ref d) => visitor.visit_map(DictMap::new(self, d)), _ => Err(Error::CannotDeserializeAny), } @@ -121,79 +73,68 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> fn deserialize_bool(self, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_bool(self.check(|v| v.as_boolean(), ExpectedKind::Boolean)?) + visitor.visit_bool(self.input.value().to_boolean()?) } fn deserialize_i8(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; - visitor.visit_i8(i.to_i8().ok_or_else(|| Error::NumberTooLarge(i.clone()))?) + visitor.visit_i8(self.input.value().to_i8()?) } fn deserialize_i16(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; - visitor.visit_i16(i.to_i16().ok_or_else(|| Error::NumberTooLarge(i.clone()))?) + visitor.visit_i16(self.input.value().to_i16()?) } fn deserialize_i32(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; - visitor.visit_i32(i.to_i32().ok_or_else(|| Error::NumberTooLarge(i.clone()))?) + visitor.visit_i32(self.input.value().to_i32()?) } fn deserialize_i64(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; - visitor.visit_i64(i.to_i64().ok_or_else(|| Error::NumberTooLarge(i.clone()))?) + visitor.visit_i64(self.input.value().to_i64()?) } fn deserialize_u8(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; - visitor.visit_u8(i.to_u8().ok_or_else(|| Error::NumberTooLarge(i.clone()))?) + visitor.visit_u8(self.input.value().to_u8()?) } fn deserialize_u16(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; - visitor.visit_u16(i.to_u16().ok_or_else(|| Error::NumberTooLarge(i.clone()))?) + visitor.visit_u16(self.input.value().to_u16()?) } fn deserialize_u32(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; - visitor.visit_u32(i.to_u32().ok_or_else(|| Error::NumberTooLarge(i.clone()))?) + visitor.visit_u32(self.input.value().to_u32()?) } fn deserialize_u64(self, visitor: V) -> Result where V: Visitor<'de> { - let i = self.check(|v| v.as_signedinteger(), ExpectedKind::SignedInteger)?; - visitor.visit_u64(i.to_u64().ok_or_else(|| Error::NumberTooLarge(i.clone()))?) + visitor.visit_u64(self.input.value().to_u64()?) } fn deserialize_f32(self, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_f32(self.check(|v| v.as_float(), ExpectedKind::Float)?) + visitor.visit_f32(self.input.value().to_float()?) } fn deserialize_f64(self, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_f64(self.check(|v| v.as_double(), ExpectedKind::Double)?) + visitor.visit_f64(self.input.value().to_double()?) } fn deserialize_char(self, visitor: V) -> Result where V: Visitor<'de> { - let fs = self.check(|v| v.as_simple_record("UnicodeScalar", Some(1)), - ExpectedKind::SimpleRecord("UnicodeScalar", Some(1)))?; - let c = fs[0].value().as_u32() - .ok_or_else(|| Error::Expected(ExpectedKind::SignedInteger, self.input.copy_via_id()))?; - visitor.visit_char(char::try_from(c).or(Err(Error::InvalidUnicodeScalar(c)))?) + visitor.visit_char(self.input.value().to_char()?) } fn deserialize_str(self, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_borrowed_str(&self.check(|v| v.as_string(), ExpectedKind::String)?) + let s: &'de str = &self.input.value().to_string()?; + visitor.visit_borrowed_str(s) } fn deserialize_string(self, visitor: V) -> Result where V: Visitor<'de> @@ -203,55 +144,46 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> fn deserialize_bytes(self, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_borrowed_bytes(&self.check(|v| v.as_bytestring(), ExpectedKind::ByteString)?) + let bs: &'de [u8] = &self.input.value().to_bytestring()?; + visitor.visit_borrowed_bytes(bs) } fn deserialize_byte_buf(self, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_byte_buf(self.check(|v| v.as_bytestring(), ExpectedKind::ByteString)?.clone()) + visitor.visit_byte_buf(self.input.value().to_bytestring()?.clone()) } fn deserialize_option(self, visitor: V) -> Result where V: Visitor<'de> { - match self.input.value().as_simple_record("None", Some(0)) { - Some(_fs) => visitor.visit_none(), - None => match self.input.value().as_simple_record("Some", Some(1)) { - Some(fs) => { - self.input = &fs[0]; - visitor.visit_some(self) - } - None => Err(Error::Expected(ExpectedKind::Option, self.input.copy_via_id())) + match self.input.value().to_option()? { + None => visitor.visit_none(), + Some(v) => { + self.input = v; + visitor.visit_some(self) } } } fn deserialize_unit(self, visitor: V) -> Result where V: Visitor<'de> { - if self.input.value().is_simple_record("tuple", Some(0)) { - visitor.visit_unit() - } else { - Err(Error::Expected(ExpectedKind::SimpleRecord("tuple", Some(0)), self.input.copy_via_id())) - } + let _fs = self.input.value().to_simple_record("tuple", Some(0))?; + visitor.visit_unit() } fn deserialize_unit_struct(self, name: &'static str, visitor: V) -> Result where V: Visitor<'de> { - if self.input.value().is_simple_record(name, Some(0)) { - visitor.visit_unit() - } else { - Err(Error::Expected(ExpectedKind::SimpleRecord(name, Some(0)), self.input.copy_via_id())) - } + let _fs = self.input.value().to_simple_record(name, Some(0))?; + visitor.visit_unit() } fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result where V: Visitor<'de> { - match super::magic::transmit_input_value(name, || self.input.clone()) { + match super::magic::transmit_input_value(name, || Ok(self.input.clone()))? { Some(v) => visitor.visit_u64(v), None => { - let fs = self.check(|v| v.as_simple_record(name, Some(1)), - ExpectedKind::SimpleRecord(name, Some(1)))?; + let fs = self.input.value().to_simple_record(name, Some(1))?; self.input = &fs[0]; visitor.visit_newtype_struct(self) } @@ -259,27 +191,32 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> } fn deserialize_seq(self, visitor: V) -> Result where V: Visitor<'de> { - let vs = self.check(|v| v.as_sequence(), ExpectedKind::Sequence)?; - visitor.visit_seq(VecSeq::new(self, vs)) + match self.input.value().as_sequence() { + Some(vs) => visitor.visit_seq(VecSeq::new(self, vs.iter())), + None => { + // Hack around serde's model: Deserialize *sets* as + // sequences, too, and reconstruct them as Rust Sets + // on the visitor side. + visitor.visit_seq(VecSeq::new(self, self.input.value().to_set()?.iter())) + } + } } fn deserialize_tuple(self, len: usize, visitor: V) -> Result where V: Visitor<'de> { - let fs = self.check(|v| v.as_simple_record("tuple", Some(len)), - ExpectedKind::SimpleRecord("tuple", Some(len)))?; - visitor.visit_seq(VecSeq::new(self, fs)) + let fs = self.input.value().to_simple_record("tuple", Some(len))?; + visitor.visit_seq(VecSeq::new(self, fs.iter())) } fn deserialize_tuple_struct(self, name: &'static str, len: usize, visitor: V) -> Result where V: Visitor<'de> { - let fs = self.check(|v| v.as_simple_record(name, Some(len)), - ExpectedKind::SimpleRecord(name, Some(len)))?; - visitor.visit_seq(VecSeq::new(self, fs)) + let fs = self.input.value().to_simple_record(name, Some(len))?; + visitor.visit_seq(VecSeq::new(self, fs.iter())) } fn deserialize_map(self, visitor: V) -> Result where V: Visitor<'de> { - let d = self.check(|v| v.as_dictionary(), ExpectedKind::Dictionary)?; + let d = self.input.value().to_dictionary()?; visitor.visit_map(DictMap::new(self, d)) } @@ -289,9 +226,8 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> visitor: V) -> Result where V: Visitor<'de> { - let fs = self.check(|v| v.as_simple_record(name, Some(fields.len())), - ExpectedKind::SimpleRecord(name, Some(fields.len())))?; - visitor.visit_seq(VecSeq::new(self, fs)) + let fs = self.input.value().to_simple_record(name, Some(fields.len()))?; + visitor.visit_seq(VecSeq::new(self, fs.iter())) } fn deserialize_enum(self, @@ -305,7 +241,7 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> fn deserialize_identifier(self, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_borrowed_str(&self.check(|v| v.as_symbol(), ExpectedKind::Symbol)?) + visitor.visit_str(&self.input.value().to_symbol()?) } fn deserialize_ignored_any(self, visitor: V) -> Result where V: Visitor<'de> @@ -314,32 +250,32 @@ impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> } } -pub struct VecSeq<'a, 'de: 'a> { - index: usize, - vec: &'de [IOValue], +pub struct VecSeq<'a, 'de: 'a, I: Iterator> { + iter: I, de: &'a mut Deserializer<'de>, } -impl<'de, 'a> VecSeq<'a, 'de> { - fn new(de: &'a mut Deserializer<'de>, vec: &'de [IOValue]) -> Self { - VecSeq { index: 0, vec, de } +impl<'de, 'a, I: Iterator> VecSeq<'a, 'de, I> { + fn new(de: &'a mut Deserializer<'de>, iter: I) -> Self { + VecSeq { iter, de } } } -impl<'de, 'a> SeqAccess<'de> for VecSeq<'a, 'de> { +impl<'de, 'a, I: Iterator> SeqAccess<'de> for VecSeq<'a, 'de, I> { type Error = Error; - fn next_element_seed(&mut self, seed: T) - -> Result> where T: DeserializeSeed<'de> + fn next_element_seed(&mut self, seed: T) -> + Result> + where + T: DeserializeSeed<'de> { - if self.index == self.vec.len() { - return Ok(None) + match self.iter.next() { + None => Ok(None), + Some(v) => { + self.de.input = v; + Ok(Some(seed.deserialize(&mut *self.de)?)) + } } - - self.de.input = &self.vec[self.index]; - self.index += 1; - let value = seed.deserialize(&mut *self.de)?; - Ok(Some(value)) } } @@ -409,12 +345,12 @@ impl<'a, 'de> VariantAccess<'de> for &'a mut Deserializer<'de> { } fn tuple_variant(self, _len: usize, visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_seq(VecSeq::new(self, &self.input.value().as_record(None).unwrap().1)) + visitor.visit_seq(VecSeq::new(self, self.input.value().as_record(None).unwrap().1.iter())) } fn struct_variant(self, fields: &'static [&'static str], visitor: V) -> Result where V: Visitor<'de> { - visitor.visit_seq(VecSeq::new(self, &self.input.value().as_record(Some(fields.len())).unwrap().1)) + visitor.visit_seq(VecSeq::new(self, self.input.value().as_record(Some(fields.len())).unwrap().1.iter())) } } diff --git a/implementations/rust/src/value/decoder.rs b/implementations/rust/src/value/decoder.rs index ba4ae48..0b0ba1c 100644 --- a/implementations/rust/src/value/decoder.rs +++ b/implementations/rust/src/value/decoder.rs @@ -1,40 +1,49 @@ -use super::reader::{Reader, is_eof_error}; +use std::borrow::Cow; +use std::marker::PhantomData; +use super::reader::{self, Reader, BinaryReader}; use super::value::IOValue; -pub use super::reader::{Result, DecodePlaceholderMap}; +pub use super::reader::IOResult as Result; -pub struct Decoder<'a, R: Reader> { +pub struct Decoder<'de, R: Reader<'de>> { pub read: R, - placeholders: Option<&'a DecodePlaceholderMap>, read_annotations: bool, + phantom: PhantomData<&'de ()>, } -impl<'a, R: Reader> Decoder<'a, R> { - pub fn new(read: R, placeholders: Option<&'a DecodePlaceholderMap>) -> - Self - { - Decoder { - read, - placeholders, - read_annotations: true, - } +pub fn from_bytes<'de>(bytes: &'de [u8]) -> + Decoder<'de, BinaryReader<'de, reader::BytesBinarySource<'de>>> +{ + Decoder::new(reader::from_bytes(bytes)) +} + +pub fn from_read<'de, 'a, IOR: std::io::Read>(read: &'a mut IOR) -> + Decoder<'de, BinaryReader<'de, reader::IOBinarySource<'a, IOR>>> +{ + Decoder::new(reader::from_read(read)) +} + +impl<'de, R: Reader<'de>> Decoder<'de, R> { + pub fn new(read: R) -> Self { + Decoder { read, read_annotations: true, phantom: PhantomData } } pub fn set_read_annotations(&mut self, read_annotations: bool) { self.read_annotations = read_annotations } - pub fn next_or_err(&mut self) -> Result { - self.read.next(self.placeholders, self.read_annotations) + pub fn demand_next(&mut self) -> Result> { + self.read.demand_next(self.read_annotations) } } -impl<'a, R: Reader> std::iter::Iterator for Decoder<'a, R> { +impl<'de, R: Reader<'de>> std::iter::Iterator for Decoder<'de, R> { type Item = Result; fn next(&mut self) -> Option { - match self.next_or_err() { - Err(e) if is_eof_error(&e) => None, - other => Some(other) + match self.read.next(self.read_annotations) { + Err(e) => Some(Err(e)), + Ok(None) => None, + Ok(Some(v)) => Some(Ok(v.into_owned())), } } } diff --git a/implementations/rust/src/value/encoder.rs b/implementations/rust/src/value/encoder.rs index b666b79..03eab97 100644 --- a/implementations/rust/src/value/encoder.rs +++ b/implementations/rust/src/value/encoder.rs @@ -1,18 +1,21 @@ -use super::value::{Value, NestedValue, Domain, IOValue, UnwrappedIOValue, Float, Double, Map}; +use super::value::{Value, NestedValue, Domain, IOValue, UnwrappedIOValue, Float, Double}; use super::writer::Writer; pub use super::writer::Result; -pub type EncodePlaceholderMap = Map; - -pub struct Encoder<'a, 'b, W: Writer> { +pub struct Encoder<'a, W: Writer> { pub write: &'a mut W, - pub placeholders: Option<&'b EncodePlaceholderMap>, } -impl<'a, 'b, W: Writer> Encoder<'a, 'b, W> { - pub fn new(write: &'a mut W, placeholders: Option<&'b EncodePlaceholderMap>) -> Self { - Encoder{ write, placeholders } +pub fn encode_bytes(v: &IOValue) -> std::io::Result> { + let mut buf: Vec = Vec::new(); + Encoder::new(&mut buf).write(v)?; + Ok(buf) +} + +impl<'a, W: Writer> Encoder<'a, W> { + pub fn new(write: &'a mut W) -> Self { + Encoder { write } } pub fn write(&mut self, v: &IOValue) -> Result { @@ -24,39 +27,36 @@ impl<'a, 'b, W: Writer> Encoder<'a, 'b, W> { } pub fn write_value(&mut self, v: &UnwrappedIOValue) -> Result { - match self.placeholders.and_then(|m| m.get(v)) { - Some(&n) => self.write.write_placeholder_ref(n), - None => match v { - Value::Boolean(b) => self.write.write_bool(*b), - Value::Float(Float(f)) => self.write.write_f32(*f), - Value::Double(Double(d)) => self.write.write_f64(*d), - Value::SignedInteger(ref b) => self.write.write_int(b), - Value::String(ref s) => self.write.write_string(s), - Value::ByteString(ref bs) => self.write.write_bytes(bs), - Value::Symbol(ref s) => self.write.write_symbol(s), - Value::Record((ref l, ref fs)) => { - self.write.open_record(fs.len())?; - self.write(IOValue::boxunwrap(l))?; - for f in fs { self.write(f)?; } - self.write.close_record() - } - Value::Sequence(ref vs) => { - self.write.open_sequence(vs.len())?; - for v in vs { self.write(v)?; } - self.write.close_sequence() - } - Value::Set(ref vs) => { - self.write.open_set(vs.len())?; - for v in vs { self.write(v)?; } - self.write.close_set() - } - Value::Dictionary(ref vs) => { - self.write.open_dictionary(vs.len())?; - for (k, v) in vs { self.write(k)?; self.write(v)?; } - self.write.close_dictionary() - } - Value::Domain(ref d) => self.write(&d.as_preserves()?) + match v { + Value::Boolean(b) => self.write.write_bool(*b), + Value::Float(Float(f)) => self.write.write_f32(*f), + Value::Double(Double(d)) => self.write.write_f64(*d), + Value::SignedInteger(ref b) => self.write.write_int(b), + Value::String(ref s) => self.write.write_string(s), + Value::ByteString(ref bs) => self.write.write_bytes(bs), + Value::Symbol(ref s) => self.write.write_symbol(s), + Value::Record((ref l, ref fs)) => { + self.write.open_record(fs.len())?; + self.write(IOValue::boxunwrap(l))?; + for f in fs { self.write(f)?; } + self.write.close_record() } + Value::Sequence(ref vs) => { + self.write.open_sequence(vs.len())?; + for v in vs { self.write(v)?; } + self.write.close_sequence() + } + Value::Set(ref vs) => { + self.write.open_set(vs.len())?; + for v in vs { self.write(v)?; } + self.write.close_set() + } + Value::Dictionary(ref vs) => { + self.write.open_dictionary(vs.len())?; + for (k, v) in vs { self.write(k)?; self.write(v)?; } + self.write.close_dictionary() + } + Value::Domain(ref d) => self.write(&d.as_preserves()?) } } } diff --git a/implementations/rust/src/value/magic.rs b/implementations/rust/src/value/magic.rs index bafbf34..1627b94 100644 --- a/implementations/rust/src/value/magic.rs +++ b/implementations/rust/src/value/magic.rs @@ -1,6 +1,4 @@ -use super::value::{ - IOValue, -}; +use super::value::IOValue; pub static MAGIC: &str = "$____Preserves_Serde_Magic"; @@ -16,20 +14,22 @@ impl<'de> serde::de::Visitor<'de> for IOValueVisitor { } } +#[inline] pub fn output_value(serializer: S, v: IOValue) -> Result { serializer.serialize_newtype_struct(MAGIC, &(Box::into_raw(Box::new(v)) as u64)) } -pub fn input_value<'de, D: serde::Deserializer<'de>>(deserializer: D) -> - Result +#[inline] +pub fn input_value<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result { deserializer.deserialize_newtype_struct(MAGIC, IOValueVisitor) } //--------------------------------------------------------------------------- +#[inline] pub fn receive_output_value(name: &'static str, magic_value: &T) -> Option { if name == MAGIC { let b = unsafe { Box::from_raw(*((magic_value as *const T) as *const u64) as *mut IOValue) }; @@ -40,13 +40,14 @@ pub fn receive_output_value(name: &'static str, magic_value: &T) -> O } } -pub fn transmit_input_value(name: &'static str, f: F) -> Option -where F: FnOnce() -> IOValue +#[inline] +pub fn transmit_input_value(name: &'static str, f: F) -> Result, crate::error::Error> +where F: FnOnce() -> Result { if name == MAGIC { - let b: Box = Box::new(f()); - Some(Box::into_raw(b) as u64) + let b: Box = Box::new(f()?); + Ok(Some(Box::into_raw(b) as u64)) } else { - None + Ok(None) } } diff --git a/implementations/rust/src/value/mod.rs b/implementations/rust/src/value/mod.rs index b437198..69cd27b 100644 --- a/implementations/rust/src/value/mod.rs +++ b/implementations/rust/src/value/mod.rs @@ -1,4 +1,3 @@ -pub mod codec; pub mod constants; pub mod de; pub mod decoder; @@ -10,17 +9,12 @@ pub mod writer; pub mod magic; -pub use codec::Codec; pub use de::Deserializer; pub use de::from_value; pub use decoder::Decoder; -pub use encoder::EncodePlaceholderMap; pub use encoder::Encoder; pub use reader::BinaryReader; -pub use reader::DecodePlaceholderMap; pub use reader::Reader; -pub use reader::is_eof_error; -pub use reader::is_syntax_error; pub use ser::Serializer; pub use ser::to_value; pub use value::AnnotatedValue; @@ -37,6 +31,10 @@ pub use value::UnwrappedIOValue; pub use value::Value; pub use writer::Writer; +pub use value::FALSE; +pub use value::TRUE; +pub use value::EMPTY_SEQ; + pub fn invert_map(m: &Map) -> Map where A: Clone, B: Clone + Ord { diff --git a/implementations/rust/src/value/reader.rs b/implementations/rust/src/value/reader.rs index a1f9799..a1c6a24 100644 --- a/implementations/rust/src/value/reader.rs +++ b/implementations/rust/src/value/reader.rs @@ -1,216 +1,365 @@ use num::bigint::BigInt; +use num::traits::cast::{FromPrimitive, ToPrimitive}; +use std::borrow::Cow; use std::convert::TryFrom; use std::convert::TryInto; use std::io::{Read, Error}; +use std::marker::PhantomData; use super::constants::{Op, InvalidOp, AtomMinor, CompoundMinor}; -use super::value::{Value, NestedValue, IOValue, UnwrappedIOValue, Map, Set}; +use super::value::{Value, NestedValue, IOValue, FALSE, TRUE, Map, Set}; +use crate::error::{self, ExpectedKind, Received, is_eof_io_error, io_eof, io_syntax_error}; -pub type Result = std::result::Result; +pub type IOResult = std::result::Result; +pub type ReaderResult = std::result::Result; #[derive(Debug)] -enum PeekState { - Eof, - Empty, - Full(u8), +pub struct CompoundBody { + minor: CompoundMinor, + limit: CompoundLimit, } -pub type DecodePlaceholderMap = Map; - -pub trait Reader { - fn next( - &mut self, - placeholders: Option<&DecodePlaceholderMap>, - read_annotations: bool, - ) -> Result; +#[derive(Debug)] +pub enum CompoundLimit { + Counted(usize), + Streaming, } -impl<'re, R: Reader> Reader for &'re mut R { - fn next( - &mut self, - placeholders: Option<&DecodePlaceholderMap>, - read_annotations: bool, - ) -> Result { - (*self).next(placeholders, read_annotations) +impl CompoundBody { + pub fn counted(minor: CompoundMinor, size: usize) -> Self { + CompoundBody { minor, limit: CompoundLimit::Counted(size) } } -} -pub struct BinaryReader<'a, R: Read> { - read: &'a mut R, - buf: PeekState, -} - -struct ConfiguredBinaryReader<'de, 'pl, 'a, R: Read> { - reader: &'de mut BinaryReader<'a, R>, - placeholders: Option<&'pl DecodePlaceholderMap>, - read_annotations: bool, -} - -struct CountedStream<'de, 'pl, 'a, R: Read> { - reader: ConfiguredBinaryReader<'de, 'pl, 'a, R>, - count: usize, -} - -impl<'de, 'pl, 'a, R: Read> Iterator for CountedStream<'de, 'pl, 'a, R> -{ - type Item = Result; - fn next(&mut self) -> Option { - if self.count == 0 { return None } - self.count -= 1; - Some(self.reader.reader.next(self.reader.placeholders, self.reader.read_annotations)) + pub fn streaming(minor: CompoundMinor) -> Self { + CompoundBody { minor, limit: CompoundLimit::Streaming } } -} -struct DelimitedStream<'de, 'pl, 'a, R: Read> { - reader: ConfiguredBinaryReader<'de, 'pl, 'a, R>, -} + pub fn more_expected<'de, R: Reader<'de>>(&mut self, read: &mut R) -> ReaderResult { + match self.limit { + CompoundLimit::Counted(ref mut n) => + if *n == 0 { + read.close_compound_counted(self.minor)?; + Ok(false) + } else { + *n = *n - 1; + Ok(true) + }, + CompoundLimit::Streaming => + Ok(!read.close_compound_stream(self.minor)?), + } + } -impl<'de, 'pl, 'a, R: Read> Iterator for DelimitedStream<'de, 'pl, 'a, R> -{ - type Item = Result; - fn next(&mut self) -> Option { - match self.reader.reader.peekend() { - Err(e) => Some(Err(e)), - Ok(true) => None, - Ok(false) => Some(self.reader.reader.next(self.reader.placeholders, self.reader.read_annotations)), + pub fn next_symbol<'de, R: Reader<'de>>(&mut self, read: &mut R) -> ReaderResult>> { + match self.more_expected(read)? { + false => Ok(None), + true => Ok(Some(read.next_symbol()?)), + } + } + + pub fn next_value<'de, R: Reader<'de>>(&mut self, read: &mut R, read_annotations: bool) -> + ReaderResult> + { + match self.more_expected(read)? { + false => Ok(None), + true => Ok(Some(read.demand_next(read_annotations)?.into_owned())), + } + } + + pub fn remainder<'de, R: Reader<'de>>(&mut self, read: &mut R, read_annotations: bool) -> + ReaderResult> + { + let mut result = Vec::new(); + while let Some(v) = self.next_value(read, read_annotations)? { + result.push(v); + } + Ok(result) + } + + pub fn ensure_more_expected<'de, R: Reader<'de>>(&mut self, read: &mut R) -> ReaderResult<()> { + if self.more_expected(read)? { + Ok(()) + } else { + Err(error::Error::MissingItem) + } + } + + pub fn ensure_complete<'de, R: Reader<'de>>(&mut self, read: &mut R) -> ReaderResult<()> { + if self.more_expected(read)? { + Err(error::Error::MissingCloseDelimiter) + } else { + Ok(()) } } } -pub fn decodeop(b: u8) -> Result<(Op, u8)> { - Ok((Op::try_from(b >> 4)?, b & 15)) -} +pub trait Reader<'de> { + fn next(&mut self, read_annotations: bool) -> IOResult>>; + fn open_record(&mut self, arity: Option) -> ReaderResult; + fn open_sequence_or_set(&mut self) -> ReaderResult; + fn open_sequence(&mut self) -> ReaderResult; + fn open_set(&mut self) -> ReaderResult; + fn open_dictionary(&mut self) -> ReaderResult; + fn close_compound_counted(&mut self, minor: CompoundMinor) -> ReaderResult<()>; + fn close_compound_stream(&mut self, minor: CompoundMinor) -> ReaderResult; -pub fn decodeint(bs: &[u8]) -> BigInt { - BigInt::from_signed_bytes_be(bs) -} + //--------------------------------------------------------------------------- -pub fn decodestr(bs: &[u8]) -> Result<&str> { - std::str::from_utf8(bs).map_err(|_| err("Invalid UTF-8")) -} - -pub fn decodebinary(minor: AtomMinor, bs: Vec) -> Result { - match minor { - AtomMinor::SignedInteger => Ok(Value::from(decodeint(&bs)).wrap()), - AtomMinor::String => Ok(Value::from(decodestr(&bs)?).wrap()), - AtomMinor::ByteString => Ok(Value::ByteString(bs.to_vec()).wrap()), - AtomMinor::Symbol => Ok(Value::symbol(decodestr(&bs)?).wrap()), + fn skip_value(&mut self) -> IOResult<()> { + // TODO efficient skipping in specific impls of this trait + let _ = self.demand_next(false)?; + Ok(()) } -} -pub fn decodecompound>>(minor: CompoundMinor, mut iter: I) -> - Result -{ - match minor { - CompoundMinor::Record => - match iter.next() { - None => Err(err("Too few elements in encoded record")), - Some(labelres) => { - let label = labelres?; - Ok(Value::record(label, iter.collect::>>()?).wrap()) - } - } - CompoundMinor::Sequence => { - Ok(Value::Sequence(iter.collect::>>()?).wrap()) + fn demand_next(&mut self, read_annotations: bool) -> IOResult> { + match self.next(read_annotations)? { + None => Err(io_eof()), + Some(v) => Ok(v) } - CompoundMinor::Set => { - let mut s = Set::new(); - for res in iter { s.insert(res?); } - Ok(Value::Set(s).wrap()) + } + + fn next_boolean(&mut self) -> ReaderResult { self.demand_next(false)?.value().to_boolean() } + fn next_i8(&mut self) -> ReaderResult { self.demand_next(false)?.value().to_i8() } + fn next_u8(&mut self) -> ReaderResult { self.demand_next(false)?.value().to_u8() } + fn next_i16(&mut self) -> ReaderResult { self.demand_next(false)?.value().to_i16() } + fn next_u16(&mut self) -> ReaderResult { self.demand_next(false)?.value().to_u16() } + fn next_i32(&mut self) -> ReaderResult { self.demand_next(false)?.value().to_i32() } + fn next_u32(&mut self) -> ReaderResult { self.demand_next(false)?.value().to_u32() } + fn next_i64(&mut self) -> ReaderResult { self.demand_next(false)?.value().to_i64() } + fn next_u64(&mut self) -> ReaderResult { self.demand_next(false)?.value().to_u64() } + fn next_float(&mut self) -> ReaderResult { self.demand_next(false)?.value().to_float() } + fn next_double(&mut self) -> ReaderResult { self.demand_next(false)?.value().to_double() } + fn next_char(&mut self) -> ReaderResult { self.demand_next(false)?.value().to_char() } + + fn next_str(&mut self) -> ReaderResult> { + Ok(Cow::Owned(self.demand_next(false)?.value().to_string()?.to_owned())) + } + + fn next_bytestring(&mut self) -> ReaderResult> { + Ok(Cow::Owned(self.demand_next(false)?.value().to_bytestring()?.to_owned())) + } + + fn next_symbol(&mut self) -> ReaderResult> { + Ok(Cow::Owned(self.demand_next(false)?.value().to_symbol()?.to_owned())) + } + + fn open_option(&mut self) -> + ReaderResult<(bool, CompoundBody)> + where + Self: Sized + { + let mut compound_body = self.open_record(None)?; + let label: &str = &compound_body.next_symbol(self)?.ok_or(error::Error::MissingItem)?; + match label { + "None" => Ok((false, compound_body)), + "Some" => Ok((true, compound_body)), + _ => Err(error::Error::Expected(ExpectedKind::Option, + Received::ReceivedRecordWithLabel(label.to_owned()))), } - CompoundMinor::Dictionary => { - let mut d = Map::new(); - while let Some(kres) = iter.next() { - let k = kres?; - match iter.next() { - Some(vres) => { - let v = vres?; - d.insert(k, v); - } - None => return Err(err("Missing dictionary value")), - } - } - Ok(Value::Dictionary(d).wrap()) + } + + fn open_simple_record(&mut self, name: &'static str, arity: Option) -> + ReaderResult + where + Self: Sized + { + let mut compound_body = self.open_record(arity)?; + let label: &str = &compound_body.next_symbol(self)?.ok_or(error::Error::MissingItem)?; + if label == name { + Ok(compound_body) + } else { + Err(error::Error::Expected(ExpectedKind::SimpleRecord(name, arity), + Received::ReceivedRecordWithLabel(label.to_owned()))) } } } -pub fn eof() -> Error { - Error::new(std::io::ErrorKind::UnexpectedEof, "EOF") -} +impl<'r, 'de, R: Reader<'de>> Reader<'de> for &'r mut R { + fn next(&mut self, read_annotations: bool) -> IOResult>> { + (*self).next(read_annotations) + } -pub fn err(s: &str) -> Error { - Error::new(std::io::ErrorKind::InvalidData, s) -} + fn open_record(&mut self, arity: Option) -> ReaderResult { + (*self).open_record(arity) + } -pub fn is_syntax_error(e: &Error) -> bool { - match e.kind() { - std::io::ErrorKind::InvalidData => true, - _ => false, + fn open_sequence_or_set(&mut self) -> ReaderResult { + (*self).open_sequence_or_set() + } + + fn open_sequence(&mut self) -> ReaderResult { + (*self).open_sequence() + } + + fn open_set(&mut self) -> ReaderResult { + (*self).open_set() + } + + fn open_dictionary(&mut self) -> ReaderResult { + (*self).open_dictionary() + } + + fn close_compound_counted(&mut self, minor: CompoundMinor) -> ReaderResult<()> { + (*self).close_compound_counted(minor) + } + + fn close_compound_stream(&mut self, minor: CompoundMinor) -> ReaderResult { + (*self).close_compound_stream(minor) } } -pub fn is_eof_error(e: &Error) -> bool { - match e.kind() { - std::io::ErrorKind::UnexpectedEof => true, - _ => false, - } +pub trait BinarySource<'de> { + fn skip(&mut self) -> IOResult<()>; + fn peek(&mut self) -> IOResult; + fn readbytes(&mut self, count: usize) -> IOResult>; + fn readbytes_into(&mut self, bs: &mut [u8]) -> IOResult<()>; } -impl<'a, R: Read> BinaryReader<'a, R> { +pub struct IOBinarySource<'a, R: Read> { + pub read: &'a mut R, + pub buf: Option, +} + +impl<'a, R: Read> IOBinarySource<'a, R> { pub fn new(read: &'a mut R) -> Self { - BinaryReader { - read, - buf: PeekState::Empty, - } + IOBinarySource { read, buf: None } + } +} + +impl<'de, 'a, R: Read> BinarySource<'de> for IOBinarySource<'a, R> { + fn skip(&mut self) -> IOResult<()> { + if let None = self.buf { unreachable!(); } + self.buf = None; + Ok(()) } - fn prime(&mut self) -> Result<()> { - if let PeekState::Empty = self.buf { - let b = &mut [0]; - match self.read.read(b)? { - 0 => self.buf = PeekState::Eof, - 1 => self.buf = PeekState::Full(b[0]), - _ => unreachable!(), + fn peek(&mut self) -> IOResult { + match self.buf { + Some(b) => Ok(b), + None => { + let b = &mut [0]; + match self.read.read(b)? { + 0 => Err(io_eof()), + 1 => { + self.buf = Some(b[0]); + Ok(b[0]) + } + _ => unreachable!(), + } } } + } + + fn readbytes(&mut self, count: usize) -> IOResult> { + if let Some(_) = self.buf { unreachable!(); } + let mut bs = vec![0; count]; + self.read.read_exact(&mut bs)?; + Ok(Cow::Owned(bs)) + } + + fn readbytes_into(&mut self, bs: &mut [u8]) -> IOResult<()> { + if let Some(_) = self.buf { unreachable!(); } + self.read.read_exact(bs) + } +} + +pub struct BytesBinarySource<'de> { + pub bytes: &'de [u8], + pub index: usize, +} + +impl<'de> BytesBinarySource<'de> { + pub fn new(bytes: &'de [u8]) -> Self { + BytesBinarySource { bytes, index: 0 } + } +} + +impl<'de> BinarySource<'de> for BytesBinarySource<'de> { + fn skip(&mut self) -> IOResult<()> { + if self.index >= self.bytes.len() { unreachable!(); } + self.index += 1; Ok(()) } - pub fn skip(&mut self) -> Result<()> { - self.prime()?; - if let PeekState::Full(_) = self.buf { - self.buf = PeekState::Empty; - } - Ok(()) - } - - pub fn peek(&mut self) -> Result { - self.prime()?; - match self.buf { - PeekState::Eof => Err(eof()), - PeekState::Empty => unreachable!(), - PeekState::Full(b) => Ok(b), + fn peek(&mut self) -> IOResult { + if self.index >= self.bytes.len() { + Err(io_eof()) + } else { + Ok(self.bytes[self.index]) } } - pub fn read(&mut self) -> Result { + fn readbytes(&mut self, count: usize) -> IOResult> { + if self.index + count > self.bytes.len() { + Err(io_eof()) + } else { + let bs = &self.bytes[self.index..self.index+count]; + self.index += count; + Ok(Cow::Borrowed(bs)) + } + } + + fn readbytes_into(&mut self, bs: &mut [u8]) -> IOResult<()> { + let count = bs.len(); + if self.index + count > self.bytes.len() { + Err(io_eof()) + } else { + bs.copy_from_slice(&self.bytes[self.index..self.index+count]); + self.index += count; + Ok(()) + } + } +} + +pub struct BinaryReader<'de, S: BinarySource<'de>> { + pub source: S, + phantom: PhantomData<&'de ()>, +} + +impl<'de, S: BinarySource<'de>> BinarySource<'de> for BinaryReader<'de, S> { + fn skip(&mut self) -> IOResult<()> { + self.source.skip() + } + fn peek(&mut self) -> IOResult { + self.source.peek() + } + fn readbytes(&mut self, count: usize) -> IOResult> { + self.source.readbytes(count) + } + fn readbytes_into(&mut self, bs: &mut [u8]) -> IOResult<()> { + self.source.readbytes_into(bs) + } +} + +pub fn from_bytes<'de>(bytes: &'de [u8]) -> + BinaryReader<'de, BytesBinarySource<'de>> +{ + BinaryReader::new(BytesBinarySource::new(bytes)) +} + +pub fn from_read<'de, 'a, IOR: std::io::Read>(read: &'a mut IOR) -> + BinaryReader<'de, IOBinarySource<'a, IOR>> +{ + BinaryReader::new(IOBinarySource::new(read)) +} + +impl<'de, S: BinarySource<'de>> BinaryReader<'de, S> { + pub fn new(source: S) -> Self { + BinaryReader { source, phantom: PhantomData } + } + + fn read(&mut self) -> IOResult { let v = self.peek()?; - if let PeekState::Full(_) = self.buf { - self.buf = PeekState::Empty; - } + self.skip()?; Ok(v) } - pub fn readbytes(&mut self, bs: &mut [u8]) -> Result<()> { - match self.buf { - PeekState::Eof => unreachable!(), - PeekState::Empty => (), - PeekState::Full(_) => unreachable!(), - }; - self.read.read_exact(bs) + fn expected(&mut self, k: ExpectedKind) -> error::Error { + match self.demand_next(true) { + Ok(v) => error::Error::Expected(k, Received::ReceivedOtherValue(v.into_owned())), + Err(e) => e.into() + } } - pub fn varint(&mut self) -> Result { + fn varint(&mut self) -> IOResult { let v = self.read()?; if v < 128 { Ok(usize::from(v)) @@ -219,7 +368,7 @@ impl<'a, R: Read> BinaryReader<'a, R> { } } - pub fn wirelength(&mut self, arg: u8) -> Result { + fn wirelength(&mut self, arg: u8) -> IOResult { if arg < 15 { Ok(usize::from(arg)) } else { @@ -227,7 +376,7 @@ impl<'a, R: Read> BinaryReader<'a, R> { } } - pub fn peekend(&mut self) -> Result { + fn peekend(&mut self) -> IOResult { if self.peek()? == 4 { self.skip()?; Ok(true) @@ -235,98 +384,452 @@ impl<'a, R: Read> BinaryReader<'a, R> { Ok(false) } } -} -impl<'re, 'a, R: Read> Reader for BinaryReader<'a, R> { - fn next( - &mut self, - placeholders: Option<&DecodePlaceholderMap>, - read_annotations: bool - ) -> - Result - { + fn gather_chunks(&mut self) -> IOResult> { + let mut bs = Vec::with_capacity(256); + while !self.peekend()? { + match decodeop(self.peek()?)? { + (Op::Atom(AtomMinor::ByteString), arg) => { + self.skip()?; + let count = self.wirelength(arg)?; + if count == 0 { + return Err(io_syntax_error("Empty binary chunks are forbidden")); + } + bs.extend_from_slice(&self.readbytes(count)?) + }, + _ => return Err(io_syntax_error("Unexpected non-format-B-ByteString chunk")) + } + } + Ok(bs) + } + + fn peek_next_nonannotation_op(&mut self) -> ReaderResult<(Op, u8)> { loop { - return match decodeop(self.read()?)? { - (Op::Misc(0), 0) => Ok(Value::from(false).wrap()), - (Op::Misc(0), 1) => Ok(Value::from(true).wrap()), - (Op::Misc(0), 2) => { - let mut bs = [0; 4]; - self.readbytes(&mut bs)?; - Ok(Value::from(f32::from_bits(u32::from_be_bytes(bs.try_into().unwrap()))).wrap()) + match decodeop(self.peek()?)? { + (Op::Misc(0), 5) => self.skip()?, + other => return Ok(other), + } + } + } + + fn next_atomic(&mut self, minor: AtomMinor, k: ExpectedKind) -> ReaderResult> { + match self.peek_next_nonannotation_op()? { + (Op::Atom(actual_minor), arg) if actual_minor == minor => { + self.skip()?; + let count = self.wirelength(arg)?; + Ok(self.readbytes(count)?) + } + (Op::Misc(2), arg) => match Op::try_from(arg)? { + Op::Atom(actual_minor) if actual_minor == minor => { + self.skip()?; + Ok(Cow::Owned(self.gather_chunks()?)) } - (Op::Misc(0), 3) => { - let mut bs = [0; 8]; - self.readbytes(&mut bs)?; - Ok(Value::from(f64::from_bits(u64::from_be_bytes(bs.try_into().unwrap()))).wrap()) + _ => Err(self.expected(k)), + }, + _ => Err(self.expected(k)), + } + } + + fn next_compound(&mut self, minor: CompoundMinor, k: ExpectedKind) -> + ReaderResult + { + match self.peek_next_nonannotation_op()? { + (Op::Compound(actual_minor), arg) if actual_minor == minor => { + self.skip()?; + Ok(CompoundBody::counted(minor, self.wirelength(arg)?)) + } + (Op::Misc(2), arg) => match Op::try_from(arg)? { + Op::Compound(actual_minor) if actual_minor == minor => { + self.skip()?; + Ok(CompoundBody::streaming(minor)) } - (Op::Misc(0), 5) => { - if read_annotations { - let mut annotations = vec![self.next(placeholders, read_annotations)?]; - while decodeop(self.peek()?)? == (Op::Misc(0), 5) { - self.skip()?; - annotations.push(self.next(placeholders, read_annotations)?); - } - let v = self.next(placeholders, read_annotations)?; - assert!(v.annotations().is_empty()); - Ok(IOValue::wrap_ann(annotations, v.value_owned())) - } else { - let _ = self.next(placeholders, read_annotations)?; - continue; + _ => Err(self.expected(k)), + }, + _ => Err(self.expected(k)), + } + } + + fn next_unsigned(&mut self, f: F) -> ReaderResult + where + F: FnOnce(u64) -> Option + { + match self.peek_next_nonannotation_op()? { + (Op::Misc(3), arg) => { + self.skip()?; + if arg > 12 { + Err(error::Error::NumberOutOfRange(decodeint(&[((arg as i8) - 16) as u8]))) + } else { + f(arg as u64).ok_or_else(|| error::Error::NumberOutOfRange(decodeint(&[arg]))) + } + } + (Op::Atom(AtomMinor::SignedInteger), arg) => { + self.skip()?; + let mut count = self.wirelength(arg)?; + if count == 0 { + return f(0).ok_or_else(|| error::Error::NumberOutOfRange(BigInt::from(0))); + } + if count > 8 { + let prefix = self.readbytes(count - 8)?; + if !(&prefix).iter().all(|b| *b == 0x00) { + let mut total = prefix.into_owned(); + total.extend_from_slice(&self.readbytes(8)?); + return Err(error::Error::NumberOutOfRange(decodeint(&total))); + } + count = 8; + } else { + if (self.peek()? & 0x80) != 0 { + return Err(error::Error::NumberOutOfRange(decodeint(&self.readbytes(count)?))) } } - (Op::Misc(0), _) => Err(err("Invalid format A encoding")), - (Op::Misc(1), arg) => { - let n = self.wirelength(arg)?; - match placeholders.and_then(|m| m.get(&n)) { - Some(v) => Ok(v.clone().wrap()), - None => Err(err("Invalid Preserves placeholder")), + let mut bs = [0; 8]; + self.readbytes_into(&mut bs[8 - count..])?; + f(u64::from_be_bytes(bs)) + .ok_or_else(|| error::Error::NumberOutOfRange(decodeint(&bs))) + } + _ => { + let i_value = self.demand_next(false)?; + let i = i_value.value().to_signedinteger()?; + let n = i.to_u64().ok_or_else(|| error::Error::NumberOutOfRange(i.clone()))?; + f(n).ok_or_else(|| error::Error::NumberOutOfRange(i.clone())) + } + } + } + + fn next_signed(&mut self, f: F) -> ReaderResult + where + F: FnOnce(i64) -> Option + { + match self.peek_next_nonannotation_op()? { + (Op::Misc(3), arg) => { + self.skip()?; + let n = arg as i64; + let n = if n > 12 { n - 16 } else { n }; + f(n).ok_or_else(|| error::Error::NumberOutOfRange(decodeint(&[n as u8]))) + } + (Op::Atom(AtomMinor::SignedInteger), arg) => { + self.skip()?; + let mut count = self.wirelength(arg)?; + if count == 0 { + return f(0).ok_or_else(|| error::Error::NumberOutOfRange(BigInt::from(0))); + } + let fill_byte = if count > 8 { + let prefix = self.readbytes(count - 8)?; + let fill_byte = if (prefix[0] & 0x80) == 0 { 0x00 } else { 0xff }; + if !(&prefix).iter().all(|b| *b == fill_byte) { + let mut total = prefix.into_owned(); + total.extend_from_slice(&self.readbytes(8)?); + return Err(error::Error::NumberOutOfRange(decodeint(&total))); } - } - (Op::Misc(2), arg) => match Op::try_from(arg)? { - Op::Atom(minor) => { - let mut bs = Vec::with_capacity(256); - while !self.peekend()? { - match self.next(placeholders, false)?.value().as_bytestring() { - Some(chunk) => bs.extend_from_slice(chunk), - None => return Err(err("Unexpected non-binary chunk")), - } - } - decodebinary(minor, bs) + count = 8; + if (self.peek()? & 0x80) != (fill_byte & 0x80) { + let mut total = vec![fill_byte]; + total.extend_from_slice(&self.readbytes(8)?); + return Err(error::Error::NumberOutOfRange(decodeint(&total))); } - Op::Compound(minor) => decodecompound(minor, DelimitedStream { - reader: ConfiguredBinaryReader { - reader: self, - placeholders, - read_annotations, - }, - }), - _ => Err(err("Invalid format C start byte")), - } - (Op::Misc(3), arg) => { - let n = if arg > 12 { i32::from(arg) - 16 } else { i32::from(arg) }; - Ok(Value::from(n).wrap()) - } - (Op::Misc(_), _) => unreachable!(), - (Op::Atom(minor), arg) => { - let count = self.wirelength(arg)?; - let mut bs = vec![0; count]; - self.readbytes(&mut bs)?; - decodebinary(minor, bs) - } - (Op::Compound(minor), arg) => { - let count = self.wirelength(arg)?; - decodecompound(minor, CountedStream { - reader: ConfiguredBinaryReader { - reader: self, - placeholders, - read_annotations, - }, - count, - }) - } - (Op::Reserved(3), 15) => continue, - (Op::Reserved(_), _) => Err(InvalidOp.into()), + fill_byte + } else { + if (self.peek()? & 0x80) == 0 { 0x00 } else { 0xff } + }; + let mut bs = [fill_byte; 8]; + self.readbytes_into(&mut bs[8 - count..])?; + f(i64::from_be_bytes(bs)) + .ok_or_else(|| error::Error::NumberOutOfRange(decodeint(&bs))) + } + _ => { + let i_value = self.demand_next(false)?; + let i = i_value.value().to_signedinteger()?; + let n = i.to_i64().ok_or_else(|| error::Error::NumberOutOfRange(i.clone()))?; + f(n).ok_or_else(|| error::Error::NumberOutOfRange(i.clone())) } } } } + +impl<'de, S: BinarySource<'de>> Reader<'de> for BinaryReader<'de, S> { + fn next(&mut self, read_annotations: bool) -> IOResult>> { + match self.peek() { + Err(e) if is_eof_io_error(&e) => return Ok(None), + Err(e) => return Err(e), + Ok(_) => (), + } + loop { + return Ok(Some(match decodeop(self.read()?)? { + (Op::Misc(0), 0) => Cow::Borrowed(&FALSE), + (Op::Misc(0), 1) => Cow::Borrowed(&TRUE), + (Op::Misc(0), 2) => { + let bs: &[u8] = &self.readbytes(4)?; + Cow::Owned(Value::from(f32::from_bits(u32::from_be_bytes(bs.try_into().unwrap()))).wrap()) + } + (Op::Misc(0), 3) => { + let bs: &[u8] = &self.readbytes(8)?; + Cow::Owned(Value::from(f64::from_bits(u64::from_be_bytes(bs.try_into().unwrap()))).wrap()) + } + (Op::Misc(0), 5) => { + if read_annotations { + let mut annotations = vec![self.demand_next(read_annotations)?.into_owned()]; + while decodeop(self.peek()?)? == (Op::Misc(0), 5) { + self.skip()?; + annotations.push(self.demand_next(read_annotations)?.into_owned()); + } + let (existing_annotations, v) = self.demand_next(read_annotations)?.into_owned().pieces(); + annotations.extend(existing_annotations); + Cow::Owned(IOValue::wrap_ann(annotations, v)) + } else { + self.skip_value()?; + continue; + } + } + (Op::Misc(0), _) => Err(io_syntax_error("Invalid format A encoding"))?, + (Op::Misc(1), _) => Err(io_syntax_error("Invalid format A encoding"))?, + (Op::Misc(2), arg) => match Op::try_from(arg)? { + Op::Atom(minor) => + Cow::Owned(decodebinary(minor, Cow::Owned(self.gather_chunks()?))?), + Op::Compound(minor) => Cow::Owned(decodecompound(minor, DelimitedStream { + reader: ConfiguredBinaryReader { + reader: self, + read_annotations, + phantom: PhantomData, + }, + })?), + _ => Err(io_syntax_error("Invalid format C start byte"))?, + } + (Op::Misc(3), arg) => { + let n = if arg > 12 { i32::from(arg) - 16 } else { i32::from(arg) }; + // TODO: prebuild these in value.rs + Cow::Owned(Value::from(n).wrap()) + } + (Op::Misc(_), _) => unreachable!(), + (Op::Atom(minor), arg) => { + let count = self.wirelength(arg)?; + Cow::Owned(decodebinary(minor, self.readbytes(count)?)?) + } + (Op::Compound(minor), arg) => { + let count = self.wirelength(arg)?; + Cow::Owned(decodecompound(minor, CountedStream { + reader: ConfiguredBinaryReader { + reader: self, + read_annotations, + phantom: PhantomData, + }, + count, + })?) + } + (Op::Reserved(3), 15) => continue, + (Op::Reserved(_), _) => return Err(InvalidOp.into()), + })) + } + } + + fn open_record(&mut self, arity: Option) -> ReaderResult { + if let Some(expected_arity) = arity { + let compound_format = + self.next_compound(CompoundMinor::Record, ExpectedKind::Record(arity))?; + if let CompoundLimit::Counted(count) = compound_format.limit { + if count != expected_arity + 1 /* we add 1 for the label */ { + return Err(error::Error::Expected(ExpectedKind::Record(arity), + Received::ReceivedSomethingElse)); + } + } + Ok(compound_format) + } else { + self.next_compound(CompoundMinor::Record, ExpectedKind::Record(None)) + } + } + + fn open_sequence_or_set(&mut self) -> ReaderResult { + match self.peek_next_nonannotation_op()? { + (Op::Compound(minor), arg) + if CompoundMinor::Sequence == minor || CompoundMinor::Set == minor => { + self.skip()?; + Ok(CompoundBody::counted(minor, self.wirelength(arg)?)) + } + (Op::Misc(2), arg) => match Op::try_from(arg)? { + Op::Compound(minor) + if CompoundMinor::Sequence == minor || CompoundMinor::Set == minor => { + self.skip()?; + Ok(CompoundBody::streaming(minor)) + } + _ => Err(self.expected(ExpectedKind::SequenceOrSet)), + } + _ => Err(self.expected(ExpectedKind::SequenceOrSet)), + } + } + + fn open_sequence(&mut self) -> ReaderResult { + self.next_compound(CompoundMinor::Sequence, ExpectedKind::Sequence) + } + + fn open_set(&mut self) -> ReaderResult { + self.next_compound(CompoundMinor::Set, ExpectedKind::Set) + } + + fn open_dictionary(&mut self) -> ReaderResult { + self.next_compound(CompoundMinor::Dictionary, ExpectedKind::Dictionary) + } + + fn close_compound_counted(&mut self, _minor: CompoundMinor) -> ReaderResult<()> { + // Nothing to do -- no close delimiter to consume + Ok(()) + } + + fn close_compound_stream(&mut self, _minor: CompoundMinor) -> ReaderResult { + Ok(self.peekend()?) + } + + fn next_boolean(&mut self) -> ReaderResult { + match self.peek_next_nonannotation_op()? { + (Op::Misc(0), 0) => { self.skip()?; Ok(false) } + (Op::Misc(0), 1) => { self.skip()?; Ok(true) } + _ => Err(self.expected(ExpectedKind::Boolean)), + } + } + + fn next_i8(&mut self) -> ReaderResult { self.next_signed(|n| n.to_i8()) } + fn next_i16(&mut self) -> ReaderResult { self.next_signed(|n| n.to_i16()) } + fn next_i32(&mut self) -> ReaderResult { self.next_signed(|n| n.to_i32()) } + fn next_i64(&mut self) -> ReaderResult { self.next_signed(|n| Some(n)) } + + fn next_u8(&mut self) -> ReaderResult { self.next_unsigned(|n| n.to_u8()) } + fn next_u16(&mut self) -> ReaderResult { self.next_unsigned(|n| n.to_u16()) } + fn next_u32(&mut self) -> ReaderResult { self.next_unsigned(|n| n.to_u32()) } + fn next_u64(&mut self) -> ReaderResult { self.next_unsigned(|n| Some(n)) } + + fn next_float(&mut self) -> ReaderResult { + match self.peek_next_nonannotation_op()? { + (Op::Misc(0), 2) => { + self.skip()?; + let bs: &[u8] = &self.readbytes(4)?; + Ok(f32::from_bits(u32::from_be_bytes(bs.try_into().unwrap()))) + }, + _ => Err(self.expected(ExpectedKind::Float)), + } + } + + fn next_double(&mut self) -> ReaderResult { + match self.peek_next_nonannotation_op()? { + (Op::Misc(0), 3) => { + self.skip()?; + let bs: &[u8] = &self.readbytes(8)?; + Ok(f64::from_bits(u64::from_be_bytes(bs.try_into().unwrap()))) + }, + _ => Err(self.expected(ExpectedKind::Double)), + } + } + + fn next_str(&mut self) -> ReaderResult> { + Ok(decodestr(self.next_atomic(AtomMinor::String, ExpectedKind::Symbol)?)?) + } + + fn next_bytestring(&mut self) -> ReaderResult> { + self.next_atomic(AtomMinor::ByteString, ExpectedKind::Symbol) + } + + fn next_symbol(&mut self) -> ReaderResult> { + Ok(decodestr(self.next_atomic(AtomMinor::Symbol, ExpectedKind::Symbol)?)?) + } +} + +struct ConfiguredBinaryReader<'de, 'a, S: BinarySource<'de>> { + reader: &'a mut BinaryReader<'de, S>, + read_annotations: bool, + phantom: PhantomData<&'de ()>, +} + +struct CountedStream<'de, 'a, S: BinarySource<'de>> { + reader: ConfiguredBinaryReader<'de, 'a, S>, + count: usize, +} + +impl<'de, 'a, S: BinarySource<'de>> Iterator for CountedStream<'de, 'a, S> +{ + type Item = IOResult>; + fn next(&mut self) -> Option { + if self.count == 0 { return None } + self.count -= 1; + Some(self.reader.reader.demand_next(self.reader.read_annotations)) + } +} + +struct DelimitedStream<'de, 'a, S: BinarySource<'de>> { + reader: ConfiguredBinaryReader<'de, 'a, S>, +} + +impl<'de, 'a, S: BinarySource<'de>> Iterator for DelimitedStream<'de, 'a, S> +{ + type Item = IOResult>; + fn next(&mut self) -> Option { + match self.reader.reader.peekend() { + Err(e) => Some(Err(e)), + Ok(true) => None, + Ok(false) => Some(self.reader.reader.demand_next(self.reader.read_annotations)), + } + } +} + +pub fn decodeop(b: u8) -> IOResult<(Op, u8)> { + Ok((Op::try_from(b >> 4)?, b & 15)) +} + +pub fn decodeint(bs: &[u8]) -> BigInt { + BigInt::from_signed_bytes_be(bs) +} + +pub fn decodestr<'de>(cow: Cow<'de, [u8]>) -> IOResult> { + match cow { + Cow::Borrowed(bs) => + Ok(Cow::Borrowed(std::str::from_utf8(bs).map_err(|_| io_syntax_error("Invalid UTF-8"))?)), + Cow::Owned(bs) => + Ok(Cow::Owned(std::str::from_utf8(&bs).map_err(|_| io_syntax_error("Invalid UTF-8"))?.to_owned())), + } +} + +pub fn decodebinary<'de>(minor: AtomMinor, bs: Cow<'de, [u8]>) -> IOResult { + Ok(match minor { + AtomMinor::SignedInteger => Value::from(decodeint(&bs)).wrap(), + AtomMinor::String => Value::String(decodestr(bs)?.into_owned()).wrap(), + AtomMinor::ByteString => Value::ByteString(bs.into_owned()).wrap(), + AtomMinor::Symbol => Value::symbol(&decodestr(bs)?).wrap(), + }) +} + +pub fn decodecompound<'de, I: Iterator>>>( + minor: CompoundMinor, + mut iter: I +) -> + IOResult +{ + match minor { + CompoundMinor::Record => + match iter.next() { + None => Err(io_syntax_error("Too few elements in encoded record")), + Some(labelres) => { + let label = labelres?.into_owned(); + let fields = iter.map(|r| r.map(|c| c.into_owned())).collect::>>()?; + Ok(Value::record(label, fields).wrap()) + } + } + CompoundMinor::Sequence => { + let vs = iter.map(|r| r.map(|c| c.into_owned())).collect::>>()?; + Ok(Value::Sequence(vs).wrap()) + } + CompoundMinor::Set => { + let mut s = Set::new(); + for res in iter { s.insert(res?.into_owned()); } + Ok(Value::Set(s).wrap()) + } + CompoundMinor::Dictionary => { + let mut d = Map::new(); + while let Some(kres) = iter.next() { + let k = kres?.into_owned(); + match iter.next() { + Some(vres) => { + let v = vres?.into_owned(); + d.insert(k, v); + } + None => return Err(io_syntax_error("Missing dictionary value")), + } + } + Ok(Value::Dictionary(d).wrap()) + } + } +} diff --git a/implementations/rust/src/value/value.rs b/implementations/rust/src/value/value.rs index 6fd2ca3..a373b78 100644 --- a/implementations/rust/src/value/value.rs +++ b/implementations/rust/src/value/value.rs @@ -1,8 +1,9 @@ use num::bigint::BigInt; use num::traits::cast::ToPrimitive; -use std::cmp::{Ordering}; +use std::cmp::Ordering; +use std::convert::TryFrom; use std::fmt::Debug; -use std::hash::{Hash,Hasher}; +use std::hash::{Hash, Hasher}; use std::ops::Index; use std::ops::IndexMut; use std::string::String; @@ -11,6 +12,8 @@ use std::vec::Vec; pub use std::collections::BTreeSet as Set; pub use std::collections::BTreeMap as Map; +use crate::error::{Error, ExpectedKind, Received}; + pub trait Domain: Sized + Debug + Clone + Eq + Hash + Ord { fn as_preserves(&self) -> Result { Err(std::io::Error::new(std::io::ErrorKind::InvalidData, @@ -30,6 +33,7 @@ pub trait NestedValue: Sized + Debug + Clone + Eq + Hash + Ord { fn annotations(&self) -> &[Self]; fn value(&self) -> &Value; + fn pieces(self) -> (Vec, Value); fn value_owned(self) -> Value; fn debug_fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -47,10 +51,6 @@ pub trait NestedValue: Sized + Debug + Clone + Eq + Hash + Ord { self.value().copy_via(f)) } - fn copy_via_id>(&self) -> M { - self.copy_via(&|d| Value::Domain(d.clone())) - } - fn to_io_value(&self) -> IOValue { self.copy_via(&|d| d.as_preserves().unwrap().value().clone()) } @@ -250,6 +250,10 @@ impl, D: Domain> Value { N::wrap(self) } + fn expected(&self, k: ExpectedKind) -> Error { + Error::Expected(k, Received::ReceivedOtherValue(self.clone().wrap().to_io_value())) + } + pub fn is_boolean(&self) -> bool { self.as_boolean().is_some() } @@ -270,6 +274,10 @@ impl, D: Domain> Value { } } + pub fn to_boolean(&self) -> Result { + self.as_boolean().ok_or_else(|| self.expected(ExpectedKind::Boolean)) + } + pub fn is_float(&self) -> bool { self.as_float().is_some() } @@ -290,6 +298,10 @@ impl, D: Domain> Value { } } + pub fn to_float(&self) -> Result { + self.as_float().ok_or_else(|| self.expected(ExpectedKind::Float)) + } + pub fn is_double(&self) -> bool { self.as_double().is_some() } @@ -310,6 +322,10 @@ impl, D: Domain> Value { } } + pub fn to_double(&self) -> Result { + self.as_double().ok_or_else(|| self.expected(ExpectedKind::Double)) + } + pub fn is_signedinteger(&self) -> bool { self.as_signedinteger().is_some() } @@ -330,6 +346,10 @@ impl, D: Domain> Value { } } + pub fn to_signedinteger(&self) -> Result<&BigInt, Error> { + self.as_signedinteger().ok_or_else(|| self.expected(ExpectedKind::SignedInteger)) + } + pub fn as_u8(&self) -> Option { self.as_signedinteger().and_then(|i| i.to_u8()) } pub fn as_i8(&self) -> Option { self.as_signedinteger().and_then(|i| i.to_i8()) } pub fn as_u16(&self) -> Option { self.as_signedinteger().and_then(|i| i.to_u16()) } @@ -339,6 +359,52 @@ impl, D: Domain> Value { pub fn as_u64(&self) -> Option { self.as_signedinteger().and_then(|i| i.to_u64()) } pub fn as_i64(&self) -> Option { self.as_signedinteger().and_then(|i| i.to_i64()) } + pub fn to_i8(&self) -> Result { + let i = self.to_signedinteger()?; + i.to_i8().ok_or_else(|| Error::NumberOutOfRange(i.clone())) + } + + pub fn to_u8(&self) -> Result { + let i = self.to_signedinteger()?; + i.to_u8().ok_or_else(|| Error::NumberOutOfRange(i.clone())) + } + + pub fn to_u16(&self) -> Result { + let i = self.to_signedinteger()?; + i.to_u16().ok_or_else(|| Error::NumberOutOfRange(i.clone())) + } + + pub fn to_i16(&self) -> Result { + let i = self.to_signedinteger()?; + i.to_i16().ok_or_else(|| Error::NumberOutOfRange(i.clone())) + } + + pub fn to_u32(&self) -> Result { + let i = self.to_signedinteger()?; + i.to_u32().ok_or_else(|| Error::NumberOutOfRange(i.clone())) + } + + pub fn to_i32(&self) -> Result { + let i = self.to_signedinteger()?; + i.to_i32().ok_or_else(|| Error::NumberOutOfRange(i.clone())) + } + + pub fn to_u64(&self) -> Result { + let i = self.to_signedinteger()?; + i.to_u64().ok_or_else(|| Error::NumberOutOfRange(i.clone())) + } + + pub fn to_i64(&self) -> Result { + let i = self.to_signedinteger()?; + i.to_i64().ok_or_else(|| Error::NumberOutOfRange(i.clone())) + } + + pub fn to_char(&self) -> Result { + let fs = self.to_simple_record("UnicodeScalar", Some(1))?; + let c = fs[0].value().to_u32()?; + char::try_from(c).or_else(|_| Err(Error::InvalidUnicodeScalar(c))) + } + pub fn is_string(&self) -> bool { self.as_string().is_some() } @@ -359,6 +425,10 @@ impl, D: Domain> Value { } } + pub fn to_string(&self) -> Result<&String, Error> { + self.as_string().ok_or_else(|| self.expected(ExpectedKind::String)) + } + pub fn is_bytestring(&self) -> bool { self.as_bytestring().is_some() } @@ -379,6 +449,10 @@ impl, D: Domain> Value { } } + pub fn to_bytestring(&self) -> Result<&Vec, Error> { + self.as_bytestring().ok_or_else(|| self.expected(ExpectedKind::ByteString)) + } + pub fn symbol(s: &str) -> Value { Value::Symbol(s.to_string()) } @@ -403,6 +477,10 @@ impl, D: Domain> Value { } } + pub fn to_symbol(&self) -> Result<&String, Error> { + self.as_symbol().ok_or_else(|| self.expected(ExpectedKind::Symbol)) + } + pub fn record(label: N, fields: Vec) -> Value { Value::Record((label.boxwrap(), fields)) } @@ -435,6 +513,10 @@ impl, D: Domain> Value { } } + pub fn to_record(&self, arity: Option) -> Result<&Record, Error> { + self.as_record(arity).ok_or_else(|| self.expected(ExpectedKind::Record(arity))) + } + pub fn simple_record(label: &str, fields: Vec) -> Value { Value::record(Value::symbol(label).wrap(), fields) } @@ -452,6 +534,23 @@ impl, D: Domain> Value { }) } + pub fn to_simple_record(&self, label: &'static str, arity: Option) -> + Result<&Vec, Error> + { + self.as_simple_record(label, arity) + .ok_or_else(|| self.expected(ExpectedKind::SimpleRecord(label, arity))) + } + + pub fn to_option(&self) -> Result, Error> { + match self.as_simple_record("None", Some(0)) { + Some(_fs) => Ok(None), + None => match self.as_simple_record("Some", Some(1)) { + Some(fs) => Ok(Some(&fs[0])), + None => Err(self.expected(ExpectedKind::Option)) + } + } + } + pub fn is_sequence(&self) -> bool { self.as_sequence().is_some() } @@ -472,6 +571,10 @@ impl, D: Domain> Value { } } + pub fn to_sequence(&self) -> Result<&Vec, Error> { + self.as_sequence().ok_or_else(|| self.expected(ExpectedKind::Sequence)) + } + pub fn is_set(&self) -> bool { self.as_set().is_some() } @@ -492,6 +595,10 @@ impl, D: Domain> Value { } } + pub fn to_set(&self) -> Result<&Set, Error> { + self.as_set().ok_or_else(|| self.expected(ExpectedKind::Set)) + } + pub fn is_dictionary(&self) -> bool { self.as_dictionary().is_some() } @@ -512,6 +619,10 @@ impl, D: Domain> Value { } } + pub fn to_dictionary(&self) -> Result<&Map, Error> { + self.as_dictionary().ok_or_else(|| self.expected(ExpectedKind::Dictionary)) + } + pub fn copy_via, E: Domain, F>(&self, f: &F) -> Value where F: Fn(&D) -> Value @@ -660,6 +771,11 @@ impl NestedValue for PlainValue { &(self.0).1 } + fn pieces(self) -> (Vec, Value) { + let AnnotatedValue(anns, v) = self.0; + (anns, v) + } + fn value_owned(self) -> Value { (self.0).1 } @@ -717,6 +833,13 @@ impl NestedValue for RcValue { &(self.0).1 } + fn pieces(self) -> (Vec, Value) { + match Rc::try_unwrap(self.0) { + Ok(AnnotatedValue(anns, v)) => (anns, v), + Err(r) => (r.0.clone(), r.1.clone()), + } + } + fn value_owned(self) -> Value { Rc::try_unwrap(self.0).unwrap_or_else(|_| panic!("value_owned on RcValue with refcount greater than one")).1 } @@ -774,6 +897,13 @@ impl NestedValue for ArcValue { &(self.0).1 } + fn pieces(self) -> (Vec, Value) { + match Arc::try_unwrap(self.0) { + Ok(AnnotatedValue(anns, v)) => (anns, v), + Err(r) => (r.0.clone(), r.1.clone()), + } + } + fn value_owned(self) -> Value { Arc::try_unwrap(self.0).unwrap_or_else(|_| panic!("value_owned on ArcValue with refcount greater than one")).1 } @@ -807,6 +937,12 @@ impl Domain for NullDomain {} pub struct IOValue(Arc>); pub type UnwrappedIOValue = Value; +lazy_static! { + pub static ref FALSE: IOValue = IOValue(Arc::new(AnnotatedValue(Vec::new(), Value::Boolean(false)))); + pub static ref TRUE: IOValue = IOValue(Arc::new(AnnotatedValue(Vec::new(), Value::Boolean(true)))); + pub static ref EMPTY_SEQ: IOValue = IOValue(Arc::new(AnnotatedValue(Vec::new(), Value::Sequence(Vec::new())))); +} + impl NestedValue for IOValue { type BoxType = Self; @@ -834,6 +970,13 @@ impl NestedValue for IOValue { &(self.0).1 } + fn pieces(self) -> (Vec, Value) { + match Arc::try_unwrap(self.0) { + Ok(AnnotatedValue(anns, v)) => (anns, v), + Err(r) => (r.0.clone(), r.1.clone()), + } + } + fn value_owned(self) -> Value { match Arc::try_unwrap(self.0) { Ok(AnnotatedValue(_anns, v)) => v, diff --git a/implementations/rust/src/value/writer.rs b/implementations/rust/src/value/writer.rs index 81e2a46..95515da 100644 --- a/implementations/rust/src/value/writer.rs +++ b/implementations/rust/src/value/writer.rs @@ -8,7 +8,6 @@ pub type Result = std::result::Result<(), Error>; pub trait Writer { fn write_annotation_prefix(&mut self) -> Result; - fn write_placeholder_ref(&mut self, v: usize) -> Result; fn write_noop(&mut self) -> Result; fn write_bool(&mut self, v: bool) -> Result; @@ -86,10 +85,6 @@ impl Writer for W { write_header(self, Op::Misc(0), 5) } - fn write_placeholder_ref(&mut self, v: usize) -> Result { - write_header(self, Op::Misc(1), v) - } - fn write_noop(&mut self) -> Result { write_op(self, Op::Reserved(3), 15) }