preserves/implementations/python/preserves/binary.py

255 lines
8.1 KiB
Python

import numbers
import struct
from .values import *
from .error import *
from .compat import basestring_, ord_
class BinaryCodec(object):
    """Common base class shared by the binary ``Decoder`` and ``Encoder``."""
class Decoder(BinaryCodec):
    """Incremental decoder for the Preserves binary syntax.

    Input bytes are held in ``self.packet`` and consumed from ``self.index``.
    More input can be appended with ``extend``. ``try_next`` and iteration
    return complete values, and stop without consuming anything when the
    buffered data ends part-way through a value.
    """

    def __init__(self, packet=b'', include_annotations=False, decode_embedded=lambda x: x):
        """Create a decoder over ``packet``.

        include_annotations: when true, every decoded value is wrapped in
            ``Annotated`` and annotations (tag 0x85) are attached to it.
            When false, annotations are decoded and then discarded.
        decode_embedded: applied to the decoded payload of an embedded
            value (tag 0x86). ``None`` makes embedded values a DecodeError.
        """
        super(Decoder, self).__init__()
        self.packet = packet
        self.index = 0
        self.include_annotations = include_annotations
        self.decode_embedded = decode_embedded

    def extend(self, data):
        """Append ``data`` to the buffer, dropping bytes already consumed."""
        self.packet = self.packet[self.index:] + data
        self.index = 0

    def nextbyte(self):
        """Consume one byte and return it as an int.

        Raises ShortPacket if the buffer is exhausted.
        """
        if self.index >= len(self.packet):
            raise ShortPacket('Short packet')
        self.index = self.index + 1
        return ord_(self.packet[self.index - 1])

    def nextbytes(self, n):
        """Consume and return the next ``n`` bytes.

        Raises ShortPacket if fewer than ``n`` bytes remain.
        """
        start = self.index
        end = start + n
        if end > len(self.packet):
            raise ShortPacket('Short packet')
        self.index = end
        return self.packet[start:end]

    def varint(self):
        """Read an unsigned base-128 varint, lowest 7-bit group first.

        Each byte contributes its low 7 bits. A byte with the high bit
        clear ends the number.
        """
        # A loop rather than recursion: a long run of continuation bytes
        # in untrusted input must end in ShortPacket, not RecursionError.
        result = 0
        shift = 0
        while True:
            b = self.nextbyte()
            result = result + ((b & 0x7f) << shift)
            if b < 128:
                return result
            shift = shift + 7

    def peekend(self):
        """Consume the end marker 0x84 if it is next.

        Returns True if the marker was consumed. Otherwise the position is
        left unchanged and False is returned. Raises ShortPacket if the
        buffer is exhausted.
        """
        matched = (self.nextbyte() == 0x84)
        if not matched:
            self.index = self.index - 1
        return matched

    def nextvalues(self):
        """Decode values up to and including the next end marker.

        Returns the values (without the marker) as a list.
        """
        result = []
        while not self.peekend():
            result.append(self.next())
        return result

    def nextint(self, n):
        """Read an ``n``-byte big-endian two's-complement integer.

        Returns 0 when ``n == 0``.
        """
        if n == 0:
            return 0
        acc = self.nextbyte()
        if acc & 0x80:
            # The first byte carries the sign; sign-extend it.
            acc = acc - 256
        for _i in range(n - 1):
            acc = (acc << 8) | self.nextbyte()
        return acc

    def wrap(self, v):
        """Wrap ``v`` in ``Annotated`` when annotations are being kept."""
        return Annotated(v) if self.include_annotations else v

    def unshift_annotation(self, a, v):
        """Prepend annotation ``a`` to ``v.annotations`` when annotations are kept.

        Returns ``v`` in either case.
        """
        if self.include_annotations:
            v.annotations.insert(0, a)
        return v

    def next(self):
        """Decode and return the next value, advancing past it.

        Raises ShortPacket if the buffer ends part-way through the value.
        Raises DecodeError for a stray end marker, an empty record, an
        embedded value without a ``decode_embedded`` function, or an
        unknown tag byte.
        """
        tag = self.nextbyte()
        if tag == 0x80:
            return self.wrap(False)
        if tag == 0x81:
            return self.wrap(True)
        if tag == 0x82:
            # 4-byte Float.
            return self.wrap(Float.from_bytes(self.nextbytes(4)))
        if tag == 0x83:
            # 8-byte big-endian double.
            return self.wrap(struct.unpack('>d', self.nextbytes(8))[0])
        if tag == 0x84:
            raise DecodeError('Unexpected end-of-stream marker')
        if tag == 0x85:
            # An annotation followed by the annotated value. The annotation is
            # inserted at the front, so nested annotations (0x85 a1 0x85 a2 v)
            # end up in source order [a1, a2].
            a = self.next()
            v = self.next()
            return self.unshift_annotation(a, v)
        if tag == 0x86:
            if self.decode_embedded is None:
                raise DecodeError('No decode_embedded function supplied')
            return self.wrap(Embedded(self.decode_embedded(self.next())))
        if 0x90 <= tag <= 0x9f:
            # Small integers: 0x90..0x9c -> 0..12 and 0x9d..0x9f -> -3..-1.
            return self.wrap(tag - (0xa0 if tag > 0x9c else 0x90))
        if 0xa0 <= tag <= 0xaf:
            # Integer of 1..16 bytes; the length is encoded in the tag.
            return self.wrap(self.nextint(tag - 0xa0 + 1))
        if tag == 0xb0:
            # Integer whose byte length follows as a varint.
            return self.wrap(self.nextint(self.varint()))
        if tag == 0xb1:
            # String: a varint length followed by UTF-8 bytes.
            return self.wrap(self.nextbytes(self.varint()).decode('utf-8'))
        if tag == 0xb2:
            # Byte string: a varint length followed by raw bytes.
            return self.wrap(self.nextbytes(self.varint()))
        if tag == 0xb3:
            # Symbol: a varint length followed by UTF-8 bytes.
            return self.wrap(Symbol(self.nextbytes(self.varint()).decode('utf-8')))
        if tag == 0xb4:
            # Record: the first element is the label and the rest are fields.
            vs = self.nextvalues()
            if not vs:
                raise DecodeError('Too few elements in encoded record')
            return self.wrap(Record(vs[0], vs[1:]))
        if tag == 0xb5:
            return self.wrap(tuple(self.nextvalues()))
        if tag == 0xb6:
            return self.wrap(frozenset(self.nextvalues()))
        if tag == 0xb7:
            return self.wrap(ImmutableDict.from_kvs(self.nextvalues()))
        raise DecodeError('Invalid tag: ' + hex(tag))

    def try_next(self):
        """Decode the next value, or return None if the buffer holds only part of it.

        On a short buffer the read position is restored, so the same value
        can be retried after ``extend``.
        """
        start = self.index
        try:
            return self.next()
        except ShortPacket:
            self.index = start
            return None

    def __iter__(self):
        return self

    def __next__(self):
        """Yield complete values until the buffered data runs out."""
        v = self.try_next()
        if v is None:
            raise StopIteration
        return v
def decode(bs, **kwargs):
    """Decode the first value in byte string ``bs`` and return it.

    Keyword arguments are forwarded to ``Decoder``. Only one value is read,
    so any bytes after it are ignored.
    """
    decoder = Decoder(packet=bs, **kwargs)
    value = decoder.next()
    return value
def decode_with_annotations(bs, **kwargs):
    """Like ``decode``, but keep annotations on the decoded value.

    Other keyword arguments are forwarded to ``Decoder``.
    """
    decoder = Decoder(packet=bs, include_annotations=True, **kwargs)
    value = decoder.next()
    return value
class Encoder(BinaryCodec):
    """Accumulating encoder for the Preserves binary syntax.

    Values passed to ``append`` are encoded onto ``self.buffer``, a
    bytearray. ``contents`` returns the bytes written so far and ``reset``
    clears them. With ``canonicalize=True``, each set element and each dict
    entry is encoded separately, and the results are written in sorted
    order of their encodings.
    """

    def __init__(self, encode_embedded=lambda x: x, canonicalize=False):
        """Create an empty encoder.

        encode_embedded: function used by the ``encode_embedded`` method.
            ``None`` makes that method raise EncodeError.
        canonicalize: when true, sort set elements and dict entries by
            their encodings.
        """
        super(Encoder, self).__init__()
        self.buffer = bytearray()
        self._encode_embedded = encode_embedded
        self._canonicalize = canonicalize

    def reset(self):
        """Discard everything written so far."""
        self.buffer = bytearray()

    def encode_embedded(self, v):
        """Apply the configured embedded-value encoder to ``v``.

        Raises EncodeError if the encoder was built with ``encode_embedded=None``.
        """
        if self._encode_embedded is None:
            raise EncodeError('No encode_embedded function supplied')
        return self._encode_embedded(v)

    def contents(self):
        """Return the bytes written so far as an immutable ``bytes`` object."""
        return bytes(self.buffer)

    def varint(self, v):
        """Write non-negative ``v`` as a base-128 varint, lowest 7-bit group first."""
        while v >= 128:
            # High bit set: more groups follow.
            self.buffer.append((v % 128) + 128)
            v = v // 128
        self.buffer.append(v)

    def encodeint(self, v):
        """Write integer ``v`` in its shortest big-endian two's-complement form.

        Uses tags 0xa0..0xaf for 1..16 bytes. Longer integers use tag 0xb0,
        followed by the byte count as a varint.
        """
        # The +1 reserves room for the sign bit.
        bitcount = (~v if v < 0 else v).bit_length() + 1
        bytecount = (bitcount + 7) // 8
        if bytecount <= 16:
            self.buffer.append(0xa0 + bytecount - 1)
        else:
            self.buffer.append(0xb0)
            self.varint(bytecount)
        # Most significant byte first. Python's >> is arithmetic, so negative
        # values yield their two's-complement bytes. This is a loop rather
        # than recursion so that very wide integers cannot exhaust the stack.
        for shift in range(8 * (bytecount - 1), -1, -8):
            self.buffer.append((v >> shift) & 255)

    def encodevalues(self, tag, items):
        """Write compound tag ``0xb0 + tag``, then each item, then end marker 0x84."""
        self.buffer.append(0xb0 + tag)
        for i in items:
            self.append(i)
        self.buffer.append(0x84)

    def encodebytes(self, tag, bs):
        """Write tag ``0xb0 + tag``, then the length of ``bs`` as a varint, then ``bs``."""
        self.buffer.append(0xb0 + tag)
        self.varint(len(bs))
        self.buffer.extend(bs)

    def encodeset(self, v):
        """Write ``v`` as a set (compound tag 6); the elements are sorted when canonicalizing."""
        if not self._canonicalize:
            self.encodevalues(6, v)
        else:
            c = Canonicalizer(self._encode_embedded)
            for i in v:
                c.entry([i])
            c.emit_entries(self, 6)

    def encodedict(self, v):
        """Write ``v`` as a dictionary (compound tag 7).

        When canonicalizing, the entries are sorted by their encoded
        key+value bytes.
        """
        if not self._canonicalize:
            self.encodevalues(7, list(dict_kvs(v)))
        else:
            c = Canonicalizer(self._encode_embedded)
            for (kk, vv) in v.items():
                c.entry([kk, vv])
            c.emit_entries(self, 7)

    def append(self, v):
        """Encode ``v`` onto the buffer.

        Raises TypeError (via ``cannot_encode``) for values with no encoding.
        """
        # Normalise first. ``preserve`` comes from the package's value module;
        # see its definition for the conversions it performs.
        v = preserve(v)
        if hasattr(v, '__preserve_write_binary__'):
            v.__preserve_write_binary__(self)
        # Identity checks for the booleans come before the integer branch,
        # because bool is a subclass of int.
        elif v is False:
            self.buffer.append(0x80)
        elif v is True:
            self.buffer.append(0x81)
        elif isinstance(v, float):
            # 8-byte big-endian double.
            self.buffer.append(0x83)
            self.buffer.extend(struct.pack('>d', v))
        elif isinstance(v, numbers.Integral):
            # Only integers have an integer encoding. Other numbers.Number
            # types (Fraction, Decimal, complex) fall through to
            # cannot_encode rather than crashing in bytearray.append or ~v.
            # int() normalises integer-like types such as numpy integers.
            v = int(v)
            if -3 <= v <= 12:
                # Small integers fit in the tag: 0..12 -> 0x90..0x9c, -3..-1 -> 0x9d..0x9f.
                self.buffer.append(0x90 + (v if v >= 0 else v + 16))
            else:
                self.encodeint(v)
        elif isinstance(v, bytes):
            # NOTE(review): on Python 3, bytearray and memoryview are not
            # ``bytes``. Unless preserve() converts them, they reach the
            # generic-iterable branch below and encode as a sequence of ints.
            # Confirm this is intended.
            self.encodebytes(2, v)
        elif isinstance(v, basestring_):
            self.encodebytes(1, v.encode('utf-8'))
        elif isinstance(v, (list, tuple)):
            self.encodevalues(5, v)
        elif isinstance(v, (set, frozenset)):
            self.encodeset(v)
        elif isinstance(v, dict):
            self.encodedict(v)
        else:
            # Any other iterable is written as a sequence.
            try:
                i = iter(v)
            except TypeError:
                i = None
            if i is None:
                self.cannot_encode(v)
            else:
                self.encodevalues(5, i)

    def cannot_encode(self, v):
        """Raise TypeError naming ``v`` as not encodable."""
        raise TypeError('Cannot preserves-encode: ' + repr(v))
class Canonicalizer:
    """Collects separately encoded entries and writes them out in sorted order.

    Encoder uses this in canonical mode, so that set elements and dict
    entries appear in a deterministic order given by their encoded bytes.
    """

    def __init__(self, encode_embedded):
        # Scratch encoder, itself canonicalizing, reused for every entry.
        self.encoder = Encoder(encode_embedded, canonicalize=True)
        self.entries = []

    def entry(self, pieces):
        """Encode all of ``pieces`` back to back and record the bytes as one entry."""
        scratch = self.encoder
        for piece in pieces:
            scratch.append(piece)
        self.entries.append(scratch.contents())
        scratch.reset()

    def emit_entries(self, outer_encoder, tag):
        """Write the recorded entries in sorted order to ``outer_encoder``.

        The entries are framed by compound tag ``0xb0 + tag`` and end marker 0x84.
        """
        out = outer_encoder.buffer
        out.append(0xb0 + tag)
        out.extend(b''.join(sorted(self.entries)))
        out.append(0x84)
def encode(v, **kwargs):
    """Return the binary encoding of ``v`` as bytes.

    Keyword arguments are forwarded to ``Encoder``.
    """
    encoder = Encoder(**kwargs)
    encoder.append(v)
    result = encoder.contents()
    return result
def canonicalize(v, **kwargs):
    """Return the canonical binary encoding of ``v``.

    This is ``encode`` with ``canonicalize=True``; other keyword arguments
    are forwarded.
    """
    canonical_bytes = encode(v, canonicalize=True, **kwargs)
    return canonical_bytes