# preserves/implementations/python/preserves/binary.py
#
# 219 lines, 7.5 KiB, Python

import numbers
import struct
from .values import *
from .error import *
from .compat import basestring_, ord_
from . import iolist
class BinaryCodec:
    """Shared base class of the binary Encoder and Decoder."""
class Decoder(BinaryCodec):
    """Decode values from the length-prefixed binary syntax written by Encoder.

    A packet is one tag byte followed by the value's body.  Compound bodies
    (records, sequences, sets, dictionaries, annotations) are runs of items,
    each preceded by a varint byte count (see decode_varint).  Tags handled:

        0xA0 false        0xA1 true          0xA2 float (4 or 8 bytes)
        0xA3 integer      0xA4 string        0xA5 byte string
        0xA6 symbol       0xA7 record        0xA8 sequence
        0xA9 set          0xAA dictionary    0xAB embedded
        0xBF annotation
    """

    def __init__(self, *, include_annotations=False, decode_embedded=lambda x: x):
        """Configure the decoder.

        include_annotations -- when true, every decoded value is passed through
            ``Annotated`` (see wrap) and annotation packets (tag 0xBF) attach
            their annotations to it; when false, annotations are skipped.
        decode_embedded -- callable applied to the decoded payload of an
            embedded value (tag 0xAB); ``None`` makes embedded values a
            DecodeError.
        """
        super(Decoder, self).__init__()
        self.include_annotations = include_annotations
        self.decode_embedded = decode_embedded

    def next(self, packet):
        """Decode the one value that occupies the whole of ``packet``.

        ``packet`` may be any bytes-like object; it is wrapped in a
        ``memoryview`` so that nested items are sliced without copying.

        Raises ShortPacket if ``packet`` is empty (or a nested length prefix
        is incomplete or overruns its container), and DecodeError for an
        unknown tag or a malformed body.
        """
        if not packet:
            raise ShortPacket('Short packet')
        if not isinstance(packet, memoryview):
            packet = memoryview(packet)
        tag = packet[0]
        packet = packet[1:]
        if tag == 0xA0:
            return self.wrap(False)
        if tag == 0xA1:
            return self.wrap(True)
        if tag == 0xA2:
            # A 4-byte body is IEEE single precision (wrapped in Float); an
            # 8-byte body is double precision and becomes a plain float.
            if len(packet) == 4:
                return self.wrap(Float(struct.unpack('>f', packet)[0]))
            if len(packet) == 8:
                return self.wrap(struct.unpack('>d', packet)[0])
            raise DecodeError('Unsupported floating-point size ' + str(len(packet)))
        if tag == 0xA3:
            return self.wrap(decode_int(packet))
        if tag == 0xA4:
            # The body ends with one terminator byte (Encoder writes 0); it is
            # dropped here without being checked.
            return self.wrap(bytes(packet[:-1]).decode('utf-8'))
        if tag == 0xA5:
            return self.wrap(bytes(packet))
        if tag == 0xA6:
            return self.wrap(Symbol(bytes(packet).decode('utf-8')))
        if tag == 0xA7:
            # First item is the record label, the remaining items its fields.
            vs = self.nextvalues(packet)
            if not vs:
                raise DecodeError('Too few elements in encoded record')
            return self.wrap(Record(vs[0], vs[1:]))
        if tag == 0xA8:
            return self.wrap(tuple(self.nextvalues(packet)))
        if tag == 0xA9:
            return self.wrap(frozenset(self.nextvalues(packet)))
        if tag == 0xAA:
            # Body alternates key and value items, as Encoder writes them.
            return self.wrap(ImmutableDict.from_kvs(self.nextvalues(packet)))
        if tag == 0xAB:
            if self.decode_embedded is None:
                raise DecodeError('No decode_embedded function supplied')
            # The embedded payload is the entire remaining body, not a
            # length-prefixed item.
            return self.wrap(Embedded(self.decode_embedded(self.next(packet))))
        if tag == 0xBF:
            # First item is the annotated value; any further items are its
            # annotations.
            if self.include_annotations:
                vs = self.nextvalues(packet)
                if not vs:
                    raise DecodeError('No elements in annotation')
                vs[0].annotations.extend(vs[1:])
                return vs[0]
            else:
                # Only the annotated value is decoded; annotations are skipped.
                e = self.nextitem(packet)
                if e is None:
                    raise DecodeError('No elements in annotation')
                return e[0]
        raise DecodeError('Invalid tag: ' + hex(tag))

    def nextvalues(self, packet):
        """Decode every length-prefixed item in ``packet``; return them as a list."""
        vs = []
        while True:
            e = self.nextitem(packet)
            if e is None:
                return vs
            vs.append(e[0])
            packet = e[1]

    def nextitem(self, packet):
        """Decode one varint-length-prefixed item from the front of ``packet``.

        Returns ``(value, rest)`` where ``rest`` is the unconsumed remainder,
        or ``None`` if ``packet`` is empty.  Raises ShortPacket if the length
        prefix is incomplete or announces more bytes than ``packet`` holds.
        """
        if not packet:
            return None
        (count, i) = decode_varint(packet)
        end = i + count
        if end > len(packet):
            # Without this check the slice below would silently truncate the
            # item and it would be decoded from partial data.
            raise ShortPacket('Short packet (item length exceeds available data)')
        item = packet[i:end]
        return (self.next(item), packet[end:])

    def wrap(self, v):
        """Wrap ``v`` in ``Annotated`` when annotations are being kept."""
        return Annotated(v) if self.include_annotations else v

    def try_next(self, packet):
        """Like next, but return ``None`` instead of raising ShortPacket."""
        try:
            return self.next(packet)
        except ShortPacket:
            return None
def decode_varint(packet):
    """Read a length prefix (as written by encode_varint) from the front of packet.

    The prefix is a big-endian run of 7-bit groups: every byte except the last
    has its high bit clear, and the last byte has its high bit (0x80) set.

    Returns ``(value, consumed)``, the decoded integer and the number of bytes
    the prefix occupied.  Raises ShortPacket if ``packet`` ends before a byte
    with the high bit set is seen.
    """
    value = 0
    position = 0
    for byte in packet:
        position += 1
        if byte & 0x80:
            # Terminating byte: remove the marker bit and fold in its payload.
            return ((value << 7) + (byte - 0x80), position)
        value = (value << 7) + byte
    raise ShortPacket('Short packet (incomplete length)')
def decode_int(packet):
    """Decode a big-endian two's-complement signed integer from ``packet``.

    An empty packet decodes to 0, mirroring ``encode_int(0)`` which emits no
    bytes.  ``packet`` may be bytes, bytearray, a memoryview, or any iterable
    of byte values (0-255).

    ``int.from_bytes`` does the conversion in a single linear pass, avoiding
    the quadratic cost of repeatedly shifting a growing Python int.
    """
    return int.from_bytes(packet, 'big', signed=True)
class StreamDecoder(object):
    """Incrementally decode the items of a Sequence that arrives in pieces.

    The initial packet must start with the Sequence tag 0xA8; everything after
    that tag is treated as a run of varint-length-prefixed items.  Iterating
    yields each item once it is completely buffered, and raises StopIteration
    when no complete item is buffered yet.  Call ``extend`` with more data and
    iterate again to continue.
    """

    def __init__(self, initial_packet, decoder = None):
        """Start a stream from ``initial_packet``.

        decoder -- Decoder used for each complete item (a default Decoder if
            omitted).

        Raises DecodeError if ``initial_packet`` is empty or does not begin
        with the Sequence tag 0xA8.
        """
        self.decoder = decoder if decoder is not None else Decoder()
        if not initial_packet:
            raise DecodeError('Empty initial packet in StreamDecoder')
        if initial_packet[0] != 0xA8:
            raise DecodeError('Initial stream packet is not a Sequence')
        self.buffer = memoryview(initial_packet[1:])

    def extend(self, data):
        """Append ``data`` (a bytes-like object) to the buffered input."""
        self.buffer = memoryview(bytes(self.buffer) + data)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next complete item, or raise StopIteration if none is buffered.

        Errors raised while decoding a fully buffered item propagate to the
        caller and leave the buffer unchanged, rather than being mistaken for
        "need more data" (which would stall the stream forever).
        """
        buf = self.buffer
        if not buf:
            raise StopIteration
        try:
            (count, start) = decode_varint(buf)
        except ShortPacket:
            # The length prefix itself has not fully arrived yet.
            raise StopIteration from None
        end = start + count
        if end > len(buf):
            # The item's body has not fully arrived yet.
            raise StopIteration
        # Decoder.next returns only the decoded value, so framing (finding the
        # item's end and advancing the buffer) is done here.
        value = self.decoder.next(buf[start:end])
        self.buffer = buf[end:]
        return value
def decode(bs, **kwargs):
    """Decode the single value encoded by the whole of ``bs``.

    Keyword arguments are passed through to Decoder.
    """
    decoder = Decoder(**kwargs)
    return decoder.next(bs)
def decode_with_annotations(bs, **kwargs):
    """Decode ``bs`` with a Decoder created with ``include_annotations=True``.

    Other keyword arguments are passed through to Decoder; also supplying
    ``include_annotations`` raises TypeError.
    """
    decoder = Decoder(**kwargs, include_annotations=True)
    return decoder.next(bs)
class Encoder(BinaryCodec):
    """Encode Python values into the length-prefixed binary syntax read by Decoder.

    Encoding builds an "iolist": ints, bytes, ``None`` and nested lists of
    those, which ``iolist.bytes`` turns into the final byte string (see
    encoded).
    """

    def __init__(self, *, encode_embedded=lambda x: x, canonicalize=False, include_annotations=None):
        """Configure the encoder.

        encode_embedded -- callable invoked by encode_embedded(); ``None``
            makes that method raise EncodeError.
        canonicalize -- when true, set members and dictionary entries are
            written in the order of their canonical encodings.
        include_annotations -- defaults to ``not canonicalize``.
        """
        super(Encoder, self).__init__()
        self.buffer = None
        self._encode_embedded = encode_embedded
        self._canonicalize = canonicalize
        if include_annotations is None:
            self.include_annotations = not canonicalize
        else:
            self.include_annotations = include_annotations

    def reset(self):
        """Discard any accumulated buffer."""
        self.buffer = None

    def encode_embedded(self, v):
        """Apply the configured embedded-value hook to ``v``.

        Raises EncodeError if the encoder was created with
        ``encode_embedded=None``.
        """
        if self._encode_embedded is None:
            raise EncodeError('No encode_embedded function supplied')
        return self._encode_embedded(v)

    def contents(self):
        """Return ``self.buffer`` flattened to bytes.

        NOTE(review): nothing in this class writes to ``self.buffer``;
        presumably callers or subclasses do -- confirm.
        """
        return iolist.bytes(self.buffer)

    def lengthprefixed(self, encoded):
        """Prefix the iolist ``encoded`` with a varint of its byte length."""
        encoded = iolist.counted(encoded)
        return [encode_varint(iolist.len(encoded)), encoded]

    def encodeditem(self, v):
        """Return the length-prefixed iolist encoding of ``v``."""
        return self.lengthprefixed(self.encoded_iolist(v))

    def encodedvalues(self, vs):
        """Return the length-prefixed encodings of each value in ``vs``."""
        return [self.encodeditem(v) for v in vs]

    def encoded(self, v):
        """Return the complete encoding of ``v`` as bytes."""
        return iolist.bytes(self.encoded_iolist(v))

    def _canonical_key(self, v):
        """Sort key for canonical ordering: the canonical encoding of ``v``.

        Uses this encoder's embedded-value hook so that embedded values are
        converted the same way for ordering as they are for output.
        """
        return canonicalize(v, encode_embedded=self._encode_embedded)

    def encoded_iolist(self, v):
        """Return the iolist encoding of ``v`` (without a length prefix).

        ``v`` is first passed through ``preserve``; objects providing
        ``__preserve_encoded__`` encode themselves.  Otherwise the Python type
        selects the tag.  Raises TypeError (via cannot_encode) for values with
        no encoding.
        """
        v = preserve(v)
        if hasattr(v, '__preserve_encoded__'): return v.__preserve_encoded__(self)
        if v is False: return 0xA0
        if v is True: return 0xA1
        if isinstance(v, float): return [0xA2, struct.pack('>d', v)]
        # Only integral numbers have a two's-complement encoding; other Number
        # types (Fraction, Decimal, complex) fall through to cannot_encode
        # instead of crashing inside encode_int.  int() normalises other
        # Integral implementations to a plain int for encode_int.
        if isinstance(v, numbers.Integral): return [0xA3, encode_int(int(v))]
        if isinstance(v, bytes): return [0xA5, v]
        # Strings carry a trailing 0 terminator byte.
        if isinstance(v, basestring_): return [0xA4, v.encode('utf-8'), 0]
        if isinstance(v, (list, tuple)): return [0xA8, self.encodedvalues(v)]
        if isinstance(v, (set, frozenset)):
            # key= avoids comparing the members themselves when two canonical
            # encodings tie (which could raise TypeError).
            members = sorted(v, key=self._canonical_key) if self._canonicalize else v
            return [0xA9, self.encodedvalues(members)]
        if isinstance(v, dict):
            entries = v.items()
            if self._canonicalize:
                entries = sorted(entries, key=lambda kv: self._canonical_key(kv[0]))
            # Keys and values are written as alternating items.
            return [0xAA, [[self.encodeditem(key), self.encodeditem(value)]
                           for (key, value) in entries]]
        # Any other iterable is encoded as a sequence.
        try:
            it = iter(v)
        except TypeError:
            it = None
        if it is not None:
            return [0xA8, self.encodedvalues(it)]
        self.cannot_encode(v)

    def cannot_encode(self, v):
        """Raise TypeError reporting that ``v`` has no binary encoding."""
        raise TypeError('Cannot preserves-encode: ' + repr(v))
def encode_varint(n):
    """Encode length ``n`` as a big-endian run of 7-bit groups, as an iolist.

    The final (least significant) group has its high bit set as a terminator;
    earlier groups have it clear.  Returns a bare int when one byte suffices,
    otherwise a nested list ``[most_significant, [..., last_byte]]``.
    """
    groups = [(n & 0x7F) | 0x80]
    rest = n >> 7
    while rest > 0:
        groups.append(rest & 0x7F)
        rest >>= 7
    # groups runs least- to most-significant; wrap outward in that order.
    result = groups[0]
    for group in groups[1:]:
        result = [group, result]
    return result
def encode_int(v):
    """Encode integer ``v`` as minimal big-endian two's-complement bytes (an iolist).

    Returns ``None`` for 0 (no bytes), the bare int 255 for -1, and otherwise
    a nested list ``[msb, [..., [lsb, None]]]`` of byte values.
    """
    if v == 0:
        return None
    if v == -1:
        return 255
    magnitude_bits = (~v if v < 0 else v).bit_length()
    # One extra bit for the sign, rounded up to whole bytes.
    width = (magnitude_bits + 1 + 7) // 8
    chain = None
    remaining = v
    for _ in range(width):
        chain = [remaining & 0xFF, chain]
        remaining >>= 8
    return chain
def encode(v, **kwargs):
    """Return the binary encoding of ``v`` as bytes.

    Keyword arguments (``encode_embedded``, ``canonicalize``,
    ``include_annotations``) are passed through to Encoder.
    """
    encoder = Encoder(**kwargs)
    return encoder.encoded(v)
def canonicalize(v, **kwargs):
    """Return the canonical binary encoding of ``v`` as bytes.

    Same as ``encode(v, canonicalize=True, **kwargs)``: set members and
    dictionary entries are written in sorted order, and
    ``include_annotations`` defaults to False.  Passing ``canonicalize`` in
    ``kwargs`` raises TypeError.
    """
    return encode(v, **kwargs, canonicalize=True)