preserves/implementations/rust/preserves/src/value/packed/reader.rs

676 lines
23 KiB
Rust

use crate::error::{self, ExpectedKind, io_eof};
use num::bigint::BigInt;
use num::traits::cast::{FromPrimitive, ToPrimitive};
use std::borrow::Cow;
use std::convert::TryFrom;
use std::convert::TryInto;
use std::fmt::Debug;
use std::io;
use std::marker::PhantomData;
use super::constants::Tag;
use super::super::{
CompoundClass,
DomainDecode,
Map,
NestedValue,
NoEmbeddedDomainCodec,
Record,
Set,
Value,
boundary as B,
reader::{
Token,
Reader,
ReaderResult,
},
signed_integer::SignedInteger,
source::BinarySource,
};
/// Work recorded on the reader's stack by the `narrow*` methods and carried
/// out by `widen` when the reader leaves the item it narrowed into.
#[derive(Debug)]
enum Continuation {
/// On the way out, discard `count` bytes from the underlying source (or
/// read it to its end when `count` is `None`), then keep unwinding.
Skip { count: Option<u64> },
/// On the way out, reinstate `count` as the reader's byte budget and stop
/// unwinding.
Sequence { count: Option<u64> },
}
/// A borrowed [`BinarySource`] plus the bookkeeping used to keep reads within
/// the byte budget of the item currently being decoded.
pub struct PackedReaderSource<'de, 'src, S: BinarySource<'de>> {
/// The underlying byte source that all reads are forwarded to.
pub source: &'src mut S,
/// Continuations pushed by `narrow`/`narrow_to_annotated_value` and popped
/// by `widen`.
stack: Vec<Continuation>,
/// Bytes still available to the current item; `None` means no budget is
/// enforced.
count: Option<u64>,
/// When `true`, `narrow` reads a varint length prefix before the item;
/// when `false`, `narrow` reads no prefix and flips this to `true`.
expect_length: bool,
/// Zero-sized marker that mentions the `'de` lifetime.
phantom: PhantomData<&'de ()>,
}
// TODO: inline PackedReaderSource.
/// Reader for the packed binary encoding; all of its state lives in the
/// wrapped [`PackedReaderSource`].
pub struct PackedReader<'de, 'src, S: BinarySource<'de>> {
/// The budget-tracking source this reader decodes from.
pub source: PackedReaderSource<'de, 'src, S>,
}
/// Snapshot of a [`PackedReaderSource`] captured by `mark` and reapplied by
/// `restore`.
pub struct ReaderMark<SourceMark> {
/// Mark obtained from the underlying source.
source_mark: SourceMark,
/// Length of the continuation stack when the mark was taken.
stack_len: usize,
/// Byte budget when the mark was taken.
count: Option<u64>,
/// Value of `expect_length` when the mark was taken.
expect_length: bool,
}
impl<'de, 'src, S: BinarySource<'de>> PackedReaderSource<'de, 'src, S> {
/// Charges `delta` bytes against the current byte budget.
///
/// Returns `false`, leaving the budget untouched, when fewer than `delta`
/// bytes remain. Returns `true` otherwise, including when no budget is set.
fn advance(&mut self, delta: u64) -> bool {
    if let Some(remaining) = self.count {
        match remaining.checked_sub(delta) {
            Some(left) => self.count = Some(left),
            None => return false,
        }
    }
    true
}
/// Like `advance`, but reports an exhausted budget as the error produced by
/// `io_eof()`.
fn advance_or_eof(&mut self, delta: u64) -> io::Result<()> {
    if !self.advance(delta) {
        return Err(io_eof());
    }
    Ok(())
}
/// Installs `expected_count` as the byte budget.
///
/// # Panics
///
/// Panics if a budget is already in place.
pub fn set_expected_count(&mut self, expected_count: u64) {
    if let Some(n) = self.count {
        panic!("Attempt to set_expected_count to {} when count already {}",
               expected_count,
               n);
    }
    self.count = Some(expected_count);
}
/// Peeks at the next byte, turning "nothing available" into the error
/// produced by `io_eof()`.
#[inline(always)]
fn peek_noeof(&mut self) -> io::Result<u8> {
    match self.peek()? {
        Some(byte) => Ok(byte),
        None => Err(io_eof()),
    }
}
/// Reads one byte: peeks it (failing at end of input), then skips past it.
#[inline(always)]
fn read(&mut self) -> io::Result<u8> {
    let byte = self.peek_noeof()?;
    self.skip().map(|()| byte)
}
/// Reads a variable-length unsigned integer.
///
/// Each byte contributes its low seven bits, most significant group first.
/// The byte with the high bit (`0x80`) set is the final one. Before each
/// left shift the accumulator is checked against a mask of high bits, and a
/// syntax error is returned if any of them is set.
#[inline(always)]
fn varint(&mut self) -> io::Result<u64> {
    let first = u64::from(self.read()?);
    if first & 0x80 != 0 {
        return Ok(first - 0x80);
    }
    let mut acc = first;
    loop {
        if acc & 0xfe0000000000000 != 0 {
            return Err(self.syntax_error("Varint length marker overflow"));
        }
        acc <<= 7;
        let next = u64::from(self.read()?);
        if next & 0x80 != 0 {
            return Ok(acc + next - 0x80);
        }
        acc += next;
    }
}
/// Reads a varint item length, charges it against the enclosing budget, and
/// makes it the new budget.
///
/// Returns the enclosing budget as it stands after the item has been
/// charged, so the caller can stash it in a continuation. Fails with a
/// syntax error when the item is longer than the enclosing budget allows.
fn narrow_to_len(&mut self) -> io::Result<Option<u64>> {
    let item_len = self.varint()?;
    if self.advance(item_len) {
        // `replace` installs the item's budget and hands back the outer one.
        let outer = self.count.replace(item_len);
        Ok(outer)
    } else {
        Err(self.syntax_error("Bad item length"))
    }
}
/// Enters the next item, pushing a `Sequence` continuation for `widen`.
///
/// When a length prefix is expected, the prefix is read and the budget is
/// narrowed to it. Otherwise no prefix is read, `expect_length` is switched
/// on, and the continuation carries a budget of `Some(0)`.
fn narrow(&mut self) -> io::Result<()> {
    let count = if self.expect_length {
        self.narrow_to_len()?
    } else {
        self.expect_length = true;
        Some(0)
    };
    self.stack.push(Continuation::Sequence { count });
    Ok(())
}
/// Narrows the budget to a length-prefixed item and pushes a `Skip`
/// continuation holding what is left of the enclosing budget, which `widen`
/// later discards.
fn narrow_to_annotated_value(&mut self) -> io::Result<()> {
    let after_value = self.narrow_to_len()?;
    self.stack.push(Continuation::Skip { count: after_value });
    Ok(())
}
/// Leaves the current item and passes `r` through on success.
///
/// Continuations are popped until a `Sequence` is found or the stack is
/// empty. Each `Skip` discards its byte count from the underlying source, or
/// reads the source to its end when it has no count. The `Sequence`
/// reinstates its saved budget and stops the unwinding.
fn widen<R: Debug>(&mut self, r: R) -> io::Result<R> {
    while let Some(continuation) = self.stack.pop() {
        match continuation {
            Continuation::Skip { count: Some(n) } => self.source.discard(n)?,
            Continuation::Skip { count: None } => {
                self.source.read_to_end()?;
            }
            Continuation::Sequence { count } => {
                self.count = count;
                break;
            }
        }
    }
    Ok(r)
}
/// Decodes the integer occupying the rest of the current budget.
///
/// A budget of zero bytes yields 0 and a budget of one byte is read as an
/// `i8`. Every other case, including no budget at all, goes through
/// `read_long_signed_integer`.
#[inline(always)]
fn read_signed_integer(&mut self) -> io::Result<SignedInteger> {
    match self.count {
        Some(0) => Ok(SignedInteger::from(0_i128)),
        Some(1) => {
            let byte = self.read()? as i8;
            Ok(SignedInteger::from(byte))
        }
        _ => self.read_long_signed_integer(),
    }
}
/// Decodes a big-endian two's-complement integer from all bytes remaining in
/// the current budget.
///
/// No bytes at all decode as 0. Non-negative values (top bit of the first
/// byte clear) are returned via `u128` when at most 16 bytes remain after
/// dropping leading `0x00` bytes, and via `BigInt` otherwise. Negative
/// values are returned via `i128` only when they fit in one, and via
/// `BigInt` otherwise.
fn read_long_signed_integer(&mut self) -> io::Result<SignedInteger> {
    let bs = self.read_to_end()?;
    let count = bs.len();
    if count == 0 {
        Ok(SignedInteger::from(0_i128))
    } else if (bs[0] & 0x80) == 0 {
        // Positive or zero: leading 0x00 bytes carry no information.
        let i = bs.iter().take_while(|&&b| b == 0).count();
        if count - i <= 16 {
            let mut buf = [0; 16];
            buf[16 - (count - i)..].copy_from_slice(&bs[i..]);
            Ok(SignedInteger::from(u128::from_be_bytes(buf)))
        } else {
            Ok(SignedInteger::from(
                Cow::Owned(BigInt::from_bytes_be(num::bigint::Sign::Plus, &bs[i..]))))
        }
    } else {
        // Negative: leading 0xff bytes are redundant sign extension.
        let i = bs.iter().take_while(|&&b| b == 0xff).count();
        let significant = count - i;
        // With fewer than 16 significant bytes, padding with 0xff supplies the
        // sign bit. With exactly 16, the value fits in an i128 only if the
        // first significant byte still has its top bit set. Otherwise the
        // stripped 0xff byte was carrying the sign, the value lies below
        // i128::MIN, and reading 16 bytes as an i128 would yield a
        // non-negative number.
        let fits_i128 =
            significant < 16 || (significant == 16 && (bs[i] & 0x80) != 0);
        if fits_i128 {
            let mut buf = [0xff; 16];
            buf[16 - significant..].copy_from_slice(&bs[i..]);
            Ok(SignedInteger::from(i128::from_be_bytes(buf)))
        } else {
            Ok(SignedInteger::from(
                Cow::Owned(BigInt::from_signed_bytes_be(&bs))))
        }
    }
}
}
impl<'de, 'src, S: BinarySource<'de>> BinarySource<'de> for PackedReaderSource<'de, 'src, S> {
/// Marks pair the underlying source's own mark with this wrapper's state.
type Mark = ReaderMark<S::Mark>;
/// Captures the underlying source's mark together with the current stack
/// length, byte budget and `expect_length` flag.
#[inline(always)]
fn mark(&mut self) -> io::Result<Self::Mark> {
    let source_mark = self.source.mark()?;
    Ok(ReaderMark {
        source_mark,
        stack_len: self.stack.len(),
        count: self.count,
        expect_length: self.expect_length,
    })
}
#[inline(always)]
fn restore(&mut self, mark: &Self::Mark) -> io::Result<()> {
let ReaderMark { source_mark, stack_len, count, expect_length } = mark;
if *stack_len > self.stack.len() {
panic!("Attempt to restore state to longer stack ({}) than exists ({})",
stack_len,
self.stack.len());
}
self.stack.truncate(*stack_len);
self.source.restore(source_mark)?;
self.count = *count;
self.expect_length = *expect_length;
Ok(())
}
/// Reports the input position of the underlying source.
fn input_position(&mut self) -> io::Result<Option<usize>> {
self.source.input_position()
}
/// Skips one byte, failing with end-of-input if the budget is exhausted.
#[inline(always)]
fn skip(&mut self) -> io::Result<()> {
// Charge the budget first so the underlying source is untouched on failure.
self.advance_or_eof(1)?;
self.source.skip()
}
/// Peeks at the next byte without consulting the underlying source when the
/// budget is exactly zero, in which case `None` is reported.
#[inline(always)]
fn peek(&mut self) -> io::Result<Option<u8>> {
    if self.count == Some(0) {
        return Ok(None);
    }
    self.source.peek()
}
/// Discards `count` bytes, failing with end-of-input if the budget holds
/// fewer than that.
#[inline(always)]
fn discard(&mut self, count: u64) -> io::Result<()> {
// Charge the budget first so the underlying source is untouched on failure.
self.advance_or_eof(count)?;
self.source.discard(count)
}
/// Reads `count` bytes, failing with end-of-input if the budget holds fewer
/// than that.
#[inline(always)]
fn readbytes(&mut self, count: u64) -> io::Result<Cow<'de, [u8]>> {
// Charge the budget first so the underlying source is untouched on failure.
self.advance_or_eof(count)?;
self.source.readbytes(count)
}
/// Fills `bs` from the source, failing with end-of-input if the budget holds
/// fewer than `bs.len()` bytes.
#[inline(always)]
fn readbytes_into(&mut self, bs: &mut [u8]) -> io::Result<()> {
// Charge the budget first so the underlying source is untouched on failure.
self.advance_or_eof(bs.len() as u64)?;
self.source.readbytes_into(bs)
}
/// Reads everything that remains: exactly the budgeted number of bytes when
/// a budget is set, otherwise the underlying source to its end.
#[inline(always)]
fn read_to_end(&mut self) -> io::Result<Cow<'de, [u8]>> {
    if let Some(n) = self.count {
        self.readbytes(n)
    } else {
        self.source.read_to_end()
    }
}
}
fn out_of_range<I: Into<BigInt>>(i: I) -> error::Error {
error::Error::NumberOutOfRange(i.into())
}
impl<'de, 'src, S: BinarySource<'de>> PackedReader<'de, 'src, S> {
/// Creates a reader over `source` with an empty continuation stack, no byte
/// budget, and no length prefix expected for the first item.
#[inline(always)]
pub fn new(source: &'src mut S) -> Self {
    let source = PackedReaderSource {
        source,
        stack: Vec::new(),
        count: None,
        expect_length: false,
        phantom: PhantomData,
    };
    PackedReader { source }
}
/// Peeks at the next byte and interprets it as a `Tag`, without consuming it.
#[inline(always)]
fn peek_tag(&mut self) -> io::Result<Tag> {
    let byte = self.source.peek_noeof()?;
    let tag = Tag::try_from(byte)?;
    Ok(tag)
}
/// Consumes the next byte and interprets it as a `Tag`.
#[inline(always)]
fn read_tag(&mut self) -> io::Result<Tag> {
    let byte = self.source.read()?;
    let tag = Tag::try_from(byte)?;
    Ok(tag)
}
/// Builds the `Expected` error for the kind of value the caller wanted.
fn expected(&mut self, k: ExpectedKind) -> error::Error {
error::Error::Expected(k)
}
/// Enters the next item and returns its tag, stepping through any
/// annotation wrappers on the way.
///
/// Each `Annotation` tag is handled by narrowing to the annotated value and
/// reading the tag found there.
fn next_nonannotation_tag(&mut self) -> io::Result<Tag> {
    self.source.narrow()?;
    let mut tag = self.read_tag()?;
    while tag == Tag::Annotation {
        self.source.narrow_to_annotated_value()?;
        tag = self.read_tag()?;
    }
    Ok(tag)
}
/// Reads the payload bytes of the next item, which must carry
/// `expected_tag`; otherwise the `Expected` error for `k` is returned.
///
/// On success the whole remaining budget is read and the item is left via
/// `widen`.
fn next_atomic(&mut self, expected_tag: Tag, k: ExpectedKind) -> ReaderResult<Cow<'de, [u8]>> {
    if self.next_nonannotation_tag()? != expected_tag {
        return Err(self.expected(k));
    }
    let payload = self.source.read_to_end()?;
    Ok(self.source.widen(payload)?)
}
/// Enters the next item, which must carry `expected_tag`; otherwise the
/// `Expected` error for `k` is returned. The reader is left inside the item.
fn next_compound(&mut self, expected_tag: Tag, k: ExpectedKind) -> ReaderResult<()>
{
    if self.next_nonannotation_tag()? != expected_tag {
        return Err(self.expected(k));
    }
    Ok(())
}
/// Reads the next item as an unsigned integer and narrows it with `f`.
///
/// Fails with `Expected(SignedInteger)` when the item is not an integer, and
/// with `NumberOutOfRange` when the value does not convert to `u128` or `f`
/// returns `None`.
#[inline(always)]
fn next_unsigned<T: FromPrimitive + Debug, F>(&mut self, f: F) -> ReaderResult<T>
where
    F: FnOnce(u128) -> Option<T>
{
    if self.next_nonannotation_tag()? != Tag::SignedInteger {
        return Err(self.expected(ExpectedKind::SignedInteger));
    }
    let n = &self.source.read_signed_integer()?;
    let i = n.try_into().map_err(|_| out_of_range(n))?;
    let narrowed = f(i).ok_or_else(|| out_of_range(i))?;
    Ok(self.source.widen(narrowed)?)
}
/// Reads the next item as a signed integer and narrows it with `f`.
///
/// Fails with `Expected(SignedInteger)` when the item is not an integer, and
/// with `NumberOutOfRange` when the value does not convert to `i128` or `f`
/// returns `None`.
#[inline(always)]
fn next_signed<T: FromPrimitive + Debug, F>(&mut self, f: F) -> ReaderResult<T>
where
    F: FnOnce(i128) -> Option<T>
{
    if self.next_nonannotation_tag()? != Tag::SignedInteger {
        return Err(self.expected(ExpectedKind::SignedInteger));
    }
    let n = &self.source.read_signed_integer()?;
    let i = n.try_into().map_err(|_| out_of_range(n))?;
    let narrowed = f(i).ok_or_else(|| out_of_range(i))?;
    Ok(self.source.widen(narrowed)?)
}
/// Builds a syntax error carrying `message` by delegating to the source.
fn syntax_error(&mut self, message: &str) -> io::Error {
self.source.syntax_error(message)
}
/// Decodes the value whose tag byte is next in the source.
///
/// Callers in this file (`next`, `next_token`) narrow into the item before
/// calling this and widen afterwards; this function itself only reads.
///
/// * `read_annotations` - when `true`, annotations are decoded and attached
///   to the value; when `false`, an annotation wrapper is stepped through and
///   only the annotated value is decoded.
/// * `decode_embedded` - decoder invoked for `Embedded` items.
///
/// Returns a syntax error for a float payload that is neither 4 nor 8 bytes,
/// a record with no elements, a dictionary key without a value, or an
/// annotation wrapper that contains no value.
fn _next<N: NestedValue, Dec: DomainDecode<N::Embedded>>(
&mut self,
read_annotations: bool,
decode_embedded: &mut Dec,
) -> io::Result<N> {
// The loop only repeats via `continue` in the skipped-annotation case.
loop {
return Ok(match self.read_tag()? {
Tag::False => N::new(false),
Tag::True => N::new(true),
Tag::Float => {
// The payload is whatever remains in the item's budget.
let bs = self.source.read_to_end()?;
match bs.len() {
4 => Value::from(f32::from_bits(u32::from_be_bytes((&bs[..]).try_into().unwrap()))).wrap(),
8 => Value::from(f64::from_bits(u64::from_be_bytes((&bs[..]).try_into().unwrap()))).wrap(),
_ => Err(self.syntax_error("Invalid floating-point width"))?,
}
}
Tag::SignedInteger => Value::SignedInteger(self.source.read_signed_integer()?).wrap(),
Tag::String => {
// Strings must end in a NUL byte, which is checked and dropped.
let bs = self.source.read_to_end()?;
Value::String(self.decode_nul_str(bs)?.into_owned()).wrap()
}
Tag::ByteString => Value::ByteString(self.source.read_to_end()?.into_owned()).wrap(),
Tag::Symbol => {
let bs = self.source.read_to_end()?;
Value::Symbol(self.decodestr(bs)?.into_owned()).wrap()
},
Tag::Record => {
// Elements are read until `next` reports the item is exhausted.
let mut vs = Vec::new();
while let Some(v) = self.next(read_annotations, decode_embedded)? {
vs.push(v);
}
if vs.is_empty() {
return Err(self.syntax_error("Too few elements in encoded record"))
}
Value::Record(Record(vs)).wrap()
}
Tag::Sequence => {
let mut vs = Vec::new();
while let Some(v) = self.next(read_annotations, decode_embedded)? {
vs.push(v);
}
Value::Sequence(vs).wrap()
}
Tag::Set => {
let mut s = Set::new();
while let Some(v) = self.next(read_annotations, decode_embedded)? {
s.insert(v);
}
Value::Set(s).wrap()
}
Tag::Dictionary => {
// Entries are encoded as alternating key and value items.
let mut d = Map::new();
while let Some(k) = self.next(read_annotations, decode_embedded)? {
match self.next(read_annotations, decode_embedded)? {
Some(v) => { d.insert(k, v); }
None => return Err(self.syntax_error("Missing dictionary value")),
}
}
Value::Dictionary(d).wrap()
}
Tag::Embedded => {
// The embedded payload is read without a length prefix.
// NOTE(review): if `decode_embedded` fails, `expect_length` stays
// `false` — confirm callers never reuse the reader after an error.
self.source.expect_length = false;
let d = decode_embedded.decode_embedded(self, read_annotations)?;
self.source.expect_length = true;
Value::Embedded(d).wrap()
}
Tag::Annotation => {
if read_annotations {
// The first nested item is the annotated value; the rest are
// annotations, appended after any the value already carries.
let underlying: Option<N> = self.next(read_annotations, decode_embedded)?;
match underlying {
Some(v) => {
let mut vs = Vec::new();
while let Some(v) = self.next(read_annotations, decode_embedded)? {
vs.push(v);
}
let (mut existing_annotations, v) = v.pieces();
existing_annotations.modify(|ws| ws.extend_from_slice(&vs[..]));
N::wrap(existing_annotations, v)
}
None => return Err(self.syntax_error("Missing value in encoded annotation")),
}
} else {
// Step into the annotated value and decode it on the next pass.
self.source.narrow_to_annotated_value()?;
continue;
}
}
});
}
}
/// Validates `cow` as UTF-8 and returns it as text, keeping borrowed input
/// borrowed and converting owned input in place.
///
/// Invalid UTF-8 yields a syntax error.
#[inline(always)]
fn decodestr<'a>(&mut self, cow: Cow<'a, [u8]>) -> io::Result<Cow<'a, str>> {
    match cow {
        Cow::Borrowed(bytes) => match std::str::from_utf8(bytes) {
            Ok(text) => Ok(Cow::Borrowed(text)),
            Err(_) => Err(self.syntax_error("Invalid UTF-8")),
        },
        Cow::Owned(bytes) => match String::from_utf8(bytes) {
            Ok(text) => Ok(Cow::Owned(text)),
            Err(_) => Err(self.syntax_error("Invalid UTF-8")),
        },
    }
}
/// Succeeds only when `bs` is non-empty and its last byte is NUL; otherwise
/// returns a syntax error.
fn check_nul(&mut self, bs: &[u8]) -> io::Result<()> {
    if bs.last() == Some(&0) {
        Ok(())
    } else {
        Err(self.syntax_error("Missing trailing NUL byte on string"))
    }
}
/// Decodes a NUL-terminated UTF-8 string: verifies the trailing NUL, drops
/// it, and validates the rest with `decodestr`.
fn decode_nul_str<'a>(&mut self, cow: Cow<'a, [u8]>) -> io::Result<Cow<'a, str>> {
    self.check_nul(&cow)?;
    // `check_nul` guarantees at least one byte, so removing the last is safe.
    let body = match cow {
        Cow::Borrowed(bytes) => Cow::Borrowed(&bytes[..bytes.len() - 1]),
        Cow::Owned(mut bytes) => {
            bytes.pop();
            Cow::Owned(bytes)
        }
    };
    self.decodestr(body)
}
}
impl<'de, 'src, S: BinarySource<'de>> Reader<'de> for PackedReader<'de, 'src, S> {
/// Reads the next complete value, or returns `Ok(None)` when nothing is
/// available at the current position.
///
/// Otherwise the item is entered with `narrow`, decoded by `_next`, and left
/// again with `widen`.
fn next<N: NestedValue, Dec: DomainDecode<N::Embedded>>(
    &mut self,
    read_annotations: bool,
    decode_embedded: &mut Dec,
) -> io::Result<Option<N>> {
    if self.source.peek()?.is_none() {
        return Ok(None);
    }
    self.source.narrow()?;
    let value = self._next(read_annotations, decode_embedded)?;
    self.source.widen(Some(value))
}
/// Enters the next item, which must be tagged as a record.
#[inline(always)]
fn open_record(&mut self) -> ReaderResult<()> {
self.next_compound(Tag::Record, ExpectedKind::Record)
}
/// Enters the next item, which must be tagged as a sequence.
#[inline(always)]
fn open_sequence(&mut self) -> ReaderResult<()> {
self.next_compound(Tag::Sequence, ExpectedKind::Sequence)
}
/// Enters the next item, which must be tagged as a set.
#[inline(always)]
fn open_set(&mut self) -> ReaderResult<()> {
self.next_compound(Tag::Set, ExpectedKind::Set)
}
/// Enters the next item, which must be tagged as a dictionary.
#[inline(always)]
fn open_dictionary(&mut self) -> ReaderResult<()> {
self.next_compound(Tag::Dictionary, ExpectedKind::Dictionary)
}
/// No-op for this reader: the argument is ignored and `Ok(())` is returned.
#[inline(always)]
fn boundary(&mut self, _b: &B::Type) -> ReaderResult<()> {
Ok(())
}
/// Reports whether the current compound is finished.
///
/// When nothing remains to peek, the compound is left via `widen` and
/// `Ok(true)` is returned; otherwise `Ok(false)`. Both arguments are ignored.
#[inline(always)]
fn close_compound(&mut self, _b: &mut B::Type, _i: &B::Item) -> ReaderResult<bool> {
    match self.source.peek()? {
        None => Ok(self.source.widen(true)?),
        Some(_) => Ok(false),
    }
}
/// Enters the next item, which must be tagged as embedded, and turns off the
/// length-prefix expectation for the item read next.
#[inline(always)]
fn open_embedded(&mut self) -> ReaderResult<()> {
self.next_compound(Tag::Embedded, ExpectedKind::Embedded)?;
self.source.expect_length = false;
Ok(())
}
/// Turns the length-prefix expectation back on and leaves the embedded item
/// via `widen`.
#[inline(always)]
fn close_embedded(&mut self) -> ReaderResult<()> {
self.source.expect_length = true;
Ok(self.source.widen(())?)
}
/// The reader reuses the mark type of its `PackedReaderSource`.
type Mark = <PackedReaderSource<'de, 'src, S> as BinarySource<'de>>::Mark;
/// Captures the current position by delegating to the source's
/// `BinarySource::mark`.
#[inline(always)]
fn mark(&mut self) -> io::Result<Self::Mark> {
BinarySource::mark(&mut self.source)
}
/// Returns to a previously captured position by delegating to the source's
/// `BinarySource::restore`.
#[inline(always)]
fn restore(&mut self, mark: &Self::Mark) -> io::Result<()> {
BinarySource::restore(&mut self.source, mark)
}
/// Reads the next token.
///
/// * When nothing remains to peek, the current item is left via `widen` and
///   `Token::End` is returned.
/// * Atoms are decoded in full and returned as `Token::Atom`.
/// * For a compound, only its tag byte is consumed and `Token::Compound` is
///   returned without widening.
/// * An embedded item is decoded with `decode_embedded`, passing
///   `read_embedded_annotations` through, and returned as `Token::Embedded`.
/// * An annotation wrapper is stepped through and the annotated value is
///   tokenised instead.
fn next_token<N: NestedValue, Dec: DomainDecode<N::Embedded>>(
&mut self,
read_embedded_annotations: bool,
decode_embedded: &mut Dec,
) -> io::Result<Token<N>> {
match self.source.peek()? {
None => self.source.widen(Token::End),
Some(_) => {
self.source.narrow()?;
// The loop only repeats via `continue` in the annotation case.
loop {
// The tag is peeked, not consumed, so `_next` can re-read it for atoms.
return match self.peek_tag()? {
Tag::False |
Tag::True |
Tag::Float |
Tag::SignedInteger |
Tag::String |
Tag::ByteString |
Tag::Symbol => {
let v = self._next(false, &mut NoEmbeddedDomainCodec)?;
self.source.widen(Token::Atom(v))
}
Tag::Record => { self.source.skip()?; Ok(Token::Compound(CompoundClass::Record)) }
Tag::Sequence => { self.source.skip()?; Ok(Token::Compound(CompoundClass::Sequence)) }
Tag::Set => { self.source.skip()?; Ok(Token::Compound(CompoundClass::Set)) }
Tag::Dictionary => { self.source.skip()?; Ok(Token::Compound(CompoundClass::Dictionary)) }
Tag::Embedded => {
self.source.skip()?;
// The embedded payload is read without a length prefix.
self.source.expect_length = false;
let t = Token::Embedded(decode_embedded.decode_embedded(
self, read_embedded_annotations)?);
self.source.expect_length = true;
self.source.widen(t)
}
Tag::Annotation => {
// Consume the tag, step into the annotated value, and look again.
self.source.skip()?;
self.source.narrow_to_annotated_value()?;
continue;
}
}
}
}
}
}
/// Reads the next item as a boolean, failing with `Expected(Boolean)` when
/// its tag is neither `False` nor `True`.
#[inline(always)]
fn next_boolean(&mut self) -> ReaderResult<bool> {
    let value = match self.next_nonannotation_tag()? {
        Tag::False => false,
        Tag::True => true,
        _ => return Err(self.expected(ExpectedKind::Boolean)),
    };
    Ok(self.source.widen(value)?)
}
/// Reads the next item as an integer of any size, failing with
/// `Expected(SignedInteger)` when it carries a different tag.
fn next_signedinteger(&mut self) -> ReaderResult<SignedInteger> {
    if self.next_nonannotation_tag()? != Tag::SignedInteger {
        return Err(self.expected(ExpectedKind::SignedInteger));
    }
    let value = self.source.read_signed_integer()?;
    Ok(self.source.widen(value)?)
}
/// Reads the next integer, failing with `NumberOutOfRange` if it does not fit an `i8`.
fn next_i8(&mut self) -> ReaderResult<i8> { self.next_signed(|n| n.to_i8()) }
/// Reads the next integer, failing with `NumberOutOfRange` if it does not fit an `i16`.
fn next_i16(&mut self) -> ReaderResult<i16> { self.next_signed(|n| n.to_i16()) }
/// Reads the next integer, failing with `NumberOutOfRange` if it does not fit an `i32`.
fn next_i32(&mut self) -> ReaderResult<i32> { self.next_signed(|n| n.to_i32()) }
/// Reads the next integer, failing with `NumberOutOfRange` if it does not fit an `i64`.
fn next_i64(&mut self) -> ReaderResult<i64> { self.next_signed(|n| n.to_i64()) }
/// Reads the next integer, failing with `NumberOutOfRange` if it does not fit an `i128`.
fn next_i128(&mut self) -> ReaderResult<i128> { self.next_signed(|n| n.to_i128()) }
/// Reads the next integer, failing with `NumberOutOfRange` if it does not fit a `u8`.
fn next_u8(&mut self) -> ReaderResult<u8> { self.next_unsigned(|n| n.to_u8()) }
/// Reads the next integer, failing with `NumberOutOfRange` if it does not fit a `u16`.
fn next_u16(&mut self) -> ReaderResult<u16> { self.next_unsigned(|n| n.to_u16()) }
/// Reads the next integer, failing with `NumberOutOfRange` if it does not fit a `u32`.
fn next_u32(&mut self) -> ReaderResult<u32> { self.next_unsigned(|n| n.to_u32()) }
/// Reads the next integer, failing with `NumberOutOfRange` if it does not fit a `u64`.
fn next_u64(&mut self) -> ReaderResult<u64> { self.next_unsigned(|n| n.to_u64()) }
/// Reads the next integer, failing with `NumberOutOfRange` if it does not fit a `u128`.
fn next_u128(&mut self) -> ReaderResult<u128> { self.next_unsigned(|n| n.to_u128()) }
/// Reads the next item as a float and returns it as an `f32`.
///
/// A 4-byte payload is taken as big-endian `f32` bits. An 8-byte payload is
/// decoded as an `f64` and cast down. Any other width is a syntax error.
fn next_f32(&mut self) -> ReaderResult<f32> {
    let payload = self.next_atomic(Tag::Float, ExpectedKind::Float)?;
    let raw: &[u8] = &payload;
    match raw.len() {
        4 => {
            let bits = u32::from_be_bytes(raw.try_into().unwrap());
            Ok(f32::from_bits(bits))
        }
        8 => {
            let bits = u64::from_be_bytes(raw.try_into().unwrap());
            Ok(f64::from_bits(bits) as f32)
        }
        _ => Err(self.syntax_error("Invalid floating-point width").into()),
    }
}
/// Reads the next item as a float and returns it as an `f64`.
///
/// A 4-byte payload is decoded as an `f32` and widened. An 8-byte payload is
/// taken as big-endian `f64` bits. Any other width is a syntax error.
fn next_f64(&mut self) -> ReaderResult<f64> {
    let payload = self.next_atomic(Tag::Float, ExpectedKind::Double)?;
    let raw: &[u8] = &payload;
    match raw.len() {
        4 => {
            let bits = u32::from_be_bytes(raw.try_into().unwrap());
            Ok(f32::from_bits(bits) as f64)
        }
        8 => {
            let bits = u64::from_be_bytes(raw.try_into().unwrap());
            Ok(f64::from_bits(bits))
        }
        _ => Err(self.syntax_error("Invalid floating-point width").into()),
    }
}
/// Reads the next item as a string: the payload must end in a NUL byte
/// (which is dropped) and the rest must be valid UTF-8.
fn next_str(&mut self) -> ReaderResult<Cow<'de, str>> {
let bs = self.next_atomic(Tag::String, ExpectedKind::String)?;
Ok(self.decode_nul_str(bs)?)
}
/// Reads the next item as a byte string and returns its payload unchanged.
fn next_bytestring(&mut self) -> ReaderResult<Cow<'de, [u8]>> {
self.next_atomic(Tag::ByteString, ExpectedKind::ByteString)
}
/// Reads the next item as a symbol; the payload must be valid UTF-8 and
/// carries no trailing NUL.
fn next_symbol(&mut self) -> ReaderResult<Cow<'de, str>> {
let bs = self.next_atomic(Tag::Symbol, ExpectedKind::Symbol)?;
Ok(self.decodestr(bs)?)
}
}