219 lines
7.5 KiB
Python
219 lines
7.5 KiB
Python
import numbers
|
|
import struct
|
|
|
|
from .values import *
|
|
from .error import *
|
|
from .compat import basestring_, ord_
|
|
from . import iolist
|
|
|
|
class BinaryCodec:
    """Common base class shared by the binary ``Decoder`` and ``Encoder``."""
|
|
|
|
class Decoder(BinaryCodec):
    """Decode one binary-encoded value from a bytes-like packet.

    An encoded value is a one-byte tag followed by that tag's payload.
    Compound values (tags 0xA7-0xAA and 0xBF) carry a series of items, each
    prefixed by its length encoded as a varint (see ``decode_varint``).

    Args:
        include_annotations: when true, every decoded value is wrapped in
            ``Annotated`` and the annotations carried by a 0xBF value are
            attached to the value they accompany; when false, values are
            returned unwrapped and annotations are skipped.
        decode_embedded: callable applied to the decoded payload of an
            embedded value (tag 0xAB). ``None`` makes embedded values an
            error.
    """

    def __init__(self, *, include_annotations=False, decode_embedded=lambda x: x):
        self.include_annotations = include_annotations
        self.decode_embedded = decode_embedded

    def next(self, packet):
        """Decode and return the single value encoded in ``packet``.

        Raises:
            ShortPacket: if ``packet`` (or a nested item) is empty or
                incomplete.
            DecodeError: if the tag is unknown or the payload is malformed.
        """
        if not packet:
            raise ShortPacket('Short packet')
        # Work on a memoryview so the slicing below does not copy the data.
        if not isinstance(packet, memoryview):
            packet = memoryview(packet)
        tag = packet[0]
        packet = packet[1:]

        # Atoms.
        if tag == 0xA0:
            return self.wrap(False)
        if tag == 0xA1:
            return self.wrap(True)
        if tag == 0xA2:
            # The payload length selects single vs. double precision.
            if len(packet) == 4:
                return self.wrap(Float(struct.unpack('>f', packet)[0]))
            if len(packet) == 8:
                return self.wrap(struct.unpack('>d', packet)[0])
            raise DecodeError('Unsupported floating-point size ' + str(len(packet)))
        if tag == 0xA3:
            return self.wrap(decode_int(packet))
        if tag == 0xA4:
            # The final payload byte is dropped: the Encoder in this module
            # appends a 0 byte after the UTF-8 data of a string.
            return self.wrap(bytes(packet[:-1]).decode('utf-8'))
        if tag == 0xA5:
            return self.wrap(bytes(packet))
        if tag == 0xA6:
            return self.wrap(Symbol(bytes(packet).decode('utf-8')))

        # Compound values: the payload is a series of length-prefixed items.
        if tag == 0xA7:
            vs = self.nextvalues(packet)
            if not vs:
                raise DecodeError('Too few elements in encoded record')
            # First item is passed as the record's first argument, the rest as its second.
            return self.wrap(Record(vs[0], vs[1:]))
        if tag == 0xA8:
            return self.wrap(tuple(self.nextvalues(packet)))
        if tag == 0xA9:
            return self.wrap(frozenset(self.nextvalues(packet)))
        if tag == 0xAA:
            return self.wrap(ImmutableDict.from_kvs(self.nextvalues(packet)))

        if tag == 0xAB:
            if self.decode_embedded is None:
                raise DecodeError('No decode_embedded function supplied')
            return self.wrap(Embedded(self.decode_embedded(self.next(packet))))

        if tag == 0xBF:
            # First item is the underlying value; any further items are its
            # annotations.
            if self.include_annotations:
                vs = self.nextvalues(packet)
                if not vs:
                    raise DecodeError('No elements in annotation')
                vs[0].annotations.extend(vs[1:])
                return vs[0]
            else:
                # Annotations are not wanted: decode only the first item.
                e = self.nextitem(packet)
                if e is None:
                    raise DecodeError('No elements in annotation')
                return e[0]

        raise DecodeError('Invalid tag: ' + hex(tag))

    def nextvalues(self, packet):
        """Decode every length-prefixed item in ``packet`` into a list."""
        vs = []
        while True:
            e = self.nextitem(packet)
            if e is None:
                return vs
            vs.append(e[0])
            packet = e[1]

    def nextitem(self, packet):
        """Decode the first length-prefixed item of ``packet``.

        Returns:
            ``None`` if ``packet`` is empty, otherwise a pair
            ``(value, remaining_packet)``.

        Raises:
            ShortPacket: if the length prefix is incomplete or fewer bytes
                than the prefix announces are available.
        """
        if not packet:
            return None
        (count, i) = decode_varint(packet)
        end = i + count
        if end > len(packet):
            # Slicing past the end would silently hand a truncated item to
            # next(); report the missing bytes instead.
            raise ShortPacket('Short packet (incomplete item)')
        return (self.next(packet[i:end]), packet[end:])

    def wrap(self, v):
        """Wrap ``v`` in ``Annotated`` when annotations are being kept."""
        return Annotated(v) if self.include_annotations else v

    def try_next(self, packet):
        """Like ``next``, but return ``None`` instead of raising ``ShortPacket``."""
        try:
            return self.next(packet)
        except ShortPacket:
            return None
|
|
|
|
def decode_varint(packet):
    """Read a variable-length unsigned integer from the start of ``packet``.

    Each byte contributes seven bits, most significant group first; the byte
    with its high bit set is the last one.

    Returns:
        A pair ``(value, bytes_consumed)``.

    Raises:
        ShortPacket: if ``packet`` ends before a byte with the high bit set.
    """
    value = 0
    for consumed, byte in enumerate(packet, 1):
        if byte & 0x80:
            # Terminating byte: strip the marker bit and finish.
            return ((value << 7) + (byte - 0x80), consumed)
        value = (value << 7) + byte
    raise ShortPacket('Short packet (incomplete length)')
|
|
|
|
def decode_int(packet):
    """Decode a big-endian two's-complement integer from ``packet``.

    Args:
        packet: bytes-like object holding the integer's bytes, most
            significant first. An empty packet decodes to 0.

    Returns:
        The signed integer value.
    """
    # int.from_bytes handles sign extension and the empty case (-> 0).
    return int.from_bytes(packet, 'big', signed=True)
|
|
|
|
class StreamDecoder(object):
    """Iterate over the items of an encoded Sequence as its bytes arrive.

    The initial packet must begin with the Sequence tag (0xA8); everything
    after the tag is buffered as a series of length-prefixed items. Iterating
    yields one decoded item at a time and stops when the buffer is empty or
    the decoder raises ``ShortPacket``. More bytes can then be supplied with
    ``extend`` and iteration resumed.

    Args:
        initial_packet: bytes-like object starting with 0xA8.
        decoder: object providing ``nextitem(packet)``; a default ``Decoder``
            is created when omitted.

    Raises:
        DecodeError: if ``initial_packet`` is empty or does not start with 0xA8.
    """

    def __init__(self, initial_packet, decoder = None):
        self.decoder = decoder if decoder is not None else Decoder()
        if not initial_packet:
            raise DecodeError('Empty initial packet in StreamDecoder')
        if initial_packet[0] != 0xA8:
            raise DecodeError('Initial stream packet is not a Sequence')
        self.buffer = memoryview(initial_packet[1:])

    def extend(self, data):
        """Append ``data`` to the bytes that have not been decoded yet."""
        self.buffer = memoryview(bytes(self.buffer) + data)

    def __iter__(self):
        return self

    def __next__(self):
        """Decode and return the next buffered item.

        Raises:
            StopIteration: if the buffer is empty or the decoder reports a
                short packet.
        """
        try:
            # nextitem() strips the length prefix and returns
            # (value, remaining_buffer); next() would return a bare value and
            # would misread the length prefix as a tag.
            e = self.decoder.nextitem(self.buffer)
        except ShortPacket:
            raise StopIteration from None
        if e is None:
            raise StopIteration
        self.buffer = e[1]
        return e[0]
|
|
|
|
def decode(bs, **kwargs):
    """Decode one value from ``bs`` with a fresh ``Decoder(**kwargs)``."""
    decoder = Decoder(**kwargs)
    return decoder.next(bs)
|
|
|
|
def decode_with_annotations(bs, **kwargs):
    """Decode one value from ``bs``, keeping annotations.

    Equivalent to ``Decoder(include_annotations=True, **kwargs).next(bs)``.
    """
    decoder = Decoder(include_annotations=True, **kwargs)
    return decoder.next(bs)
|
|
|
|
class Encoder(BinaryCodec):
    """Encode values into the tagged binary format read by ``Decoder``.

    Encoded output is assembled as an "iolist" (nested lists of ints, bytes
    and ``None``) and flattened with ``iolist.bytes``.

    Args:
        encode_embedded: callable used by ``encode_embedded``; ``None`` makes
            that method raise ``EncodeError``.
        canonicalize: when true, set members and dictionary entries are
            emitted sorted by the canonical encoding of the member / key.
        include_annotations: stored on the instance; defaults to
            ``not canonicalize`` when left as ``None``.
    """

    def __init__(self, *, encode_embedded=lambda x: x, canonicalize=False, include_annotations=None):
        super().__init__()
        self.buffer = None
        self._encode_embedded = encode_embedded
        self._canonicalize = canonicalize
        self.include_annotations = (
            (not canonicalize) if include_annotations is None else include_annotations
        )

    def reset(self):
        """Clear ``self.buffer``."""
        self.buffer = None

    def encode_embedded(self, v):
        """Apply the configured ``encode_embedded`` callable to ``v``.

        Raises:
            EncodeError: if no callable was supplied.
        """
        handler = self._encode_embedded
        if handler is None:
            raise EncodeError('No encode_embedded function supplied')
        return handler(v)

    def contents(self):
        """Return ``self.buffer`` flattened to bytes via ``iolist.bytes``."""
        # NOTE(review): nothing in this class assigns a non-None buffer.
        return iolist.bytes(self.buffer)

    def lengthprefixed(self, encoded):
        """Return ``encoded`` preceded by its length as a varint."""
        counted = iolist.counted(encoded)
        prefix = encode_varint(iolist.len(counted))
        return [prefix, counted]

    def encodeditem(self, v):
        """Encode ``v`` and prefix it with its encoded length."""
        return self.lengthprefixed(self.encoded_iolist(v))

    def encodedvalues(self, vs):
        """Encode each element of ``vs`` as a length-prefixed item."""
        return list(map(self.encodeditem, vs))

    def encoded(self, v):
        """Encode ``v`` and return the result flattened via ``iolist.bytes``."""
        return iolist.bytes(self.encoded_iolist(v))

    def encoded_iolist(self, v):
        """Encode ``v`` as an iolist: a tag followed by the tag's payload.

        The checks run in a fixed order; the first that matches wins. Values
        that match none of them but are iterable are encoded as a sequence.
        Anything else is handed to ``cannot_encode``.
        """
        v = preserve(v)

        # Objects that know how to encode themselves take precedence.
        if hasattr(v, '__preserve_encoded__'):
            return v.__preserve_encoded__(self)

        # Booleans are matched by identity before the numeric checks.
        if v is False:
            return 0xA0
        if v is True:
            return 0xA1
        if isinstance(v, float):
            return [0xA2, struct.pack('>d', v)]
        if isinstance(v, numbers.Number):
            return [0xA3, encode_int(v)]
        if isinstance(v, bytes):
            return [0xA5, v]
        if isinstance(v, basestring_):
            # Strings are written as UTF-8 followed by a 0 byte.
            return [0xA4, v.encode('utf-8'), 0]
        if isinstance(v, (list, tuple)):
            return [0xA8, self.encodedvalues(v)]

        if isinstance(v, (set, frozenset)):
            if not self._canonicalize:
                return [0xA9, self.encodedvalues(v)]
            # Order members by their canonical encoding.
            ordered = sorted((canonicalize(member), member) for member in v)
            return [0xA9, [self.encodeditem(member) for (_canon, member) in ordered]]

        if isinstance(v, dict):
            if self._canonicalize:
                # Order entries by the canonical encoding of the key.
                ordered = sorted((canonicalize(key), key, val) for (key, val) in v.items())
                pairs = [(key, val) for (_canon, key, val) in ordered]
            else:
                pairs = v.items()
            return [0xAA, [[self.encodeditem(key), self.encodeditem(val)]
                           for (key, val) in pairs]]

        # Fallback: any other iterable is encoded as a sequence.
        try:
            elements = iter(v)
        except TypeError:
            elements = None
        if elements is None:
            self.cannot_encode(v)
            return None
        return [0xA8, self.encodedvalues(elements)]

    def cannot_encode(self, v):
        """Raise ``TypeError`` for a value that has no encoding."""
        raise TypeError('Cannot preserves-encode: ' + repr(v))
|
|
|
|
def encode_varint(n):
    """Encode ``n`` as a varint, returned as a nested list of byte values.

    The least significant seven bits form the innermost element, with the
    high bit set to mark the final byte. Each further group of seven bits is
    wrapped around it as ``[group, inner]``, so the most significant group
    ends up outermost. A value below 128 is returned as a bare int.
    """
    encoded = 0x80 | (n & 0x7F)
    remaining = n >> 7
    while remaining > 0:
        encoded = [remaining & 0x7F, encoded]
        remaining >>= 7
    return encoded
|
|
|
|
def encode_int(v):
    """Encode integer ``v`` as big-endian two's-complement byte values.

    Returns:
        ``None`` for 0, the bare int 255 for -1, otherwise a nested list
        ``[most_significant, [..., [least_significant, None]]]`` using the
        fewest bytes that still leave room for the sign bit.
    """
    if v == 0:
        return None
    if v == -1:
        return 255

    # Bits needed for the magnitude plus one sign bit, rounded up to bytes.
    magnitude = ~v if v < 0 else v
    nbytes = (magnitude.bit_length() + 8) // 8

    # Peel off bytes from the least significant end; each new byte wraps the
    # previous ones so the most significant byte ends up outermost.
    chain = None
    remaining = v
    for _ in range(nbytes):
        chain = [remaining & 0xFF, chain]
        remaining >>= 8
    return chain
|
|
|
|
def encode(v, **kwargs):
    """Encode ``v`` with a fresh ``Encoder(**kwargs)`` and return the result."""
    encoder = Encoder(**kwargs)
    return encoder.encoded(v)
|
|
|
|
def canonicalize(v, **kwargs):
    """Encode ``v`` via ``encode`` with ``canonicalize=True``.

    Any extra keyword arguments are forwarded to ``encode`` unchanged.
    """
    canonical = encode(v, canonicalize=True, **kwargs)
    return canonical
|