import numbers
import struct

from .values import *
from .error import *
from .compat import basestring_, ord_
from . import iolist


class BinaryCodec(object):
    """Common base class for the binary encoder and decoder."""
    pass


class Decoder(BinaryCodec):
    """Decode binary packets into Python values.

    Keyword arguments:
      include_annotations -- when true, every decoded value is wrapped in
          ``Annotated`` and the annotations carried by tag 0xBF are attached
          to the value they annotate; when false, annotations are skipped.
      decode_embedded -- applied to the decoded payload of an embedded value
          (tag 0xAB); if ``None``, embedded values raise ``DecodeError``.
    """

    def __init__(self, *, include_annotations=False, decode_embedded=lambda x: x):
        self.include_annotations = include_annotations
        self.decode_embedded = decode_embedded

    def next(self, packet):
        """Decode ``packet``, a bytes-like object holding exactly one value.

        Raises ``ShortPacket`` if ``packet`` is empty or an item inside it is
        incomplete, and ``DecodeError`` for other malformed input.
        """
        if not packet:
            raise ShortPacket('Short packet')
        if not isinstance(packet, memoryview):
            # Slicing a memoryview (below and in nextitem) does not copy.
            packet = memoryview(packet)
        tag = packet[0]
        packet = packet[1:]
        if tag == 0xA0:
            return self.wrap(False)
        if tag == 0xA1:
            return self.wrap(True)
        if tag == 0xA2:
            if len(packet) == 4:
                return self.wrap(Float(struct.unpack('>f', packet)[0]))
            if len(packet) == 8:
                return self.wrap(struct.unpack('>d', packet)[0])
            raise DecodeError('Unsupported floating-point size ' + str(len(packet)))
        if tag == 0xA3:
            return self.wrap(decode_int(packet))
        if tag == 0xA4:
            # The encoder appends a 0 byte after every string's UTF-8; drop it.
            return self.wrap(bytes(packet[:-1]).decode('utf-8'))
        if tag == 0xA5:
            return self.wrap(bytes(packet))
        if tag == 0xA6:
            return self.wrap(Symbol(bytes(packet).decode('utf-8')))
        if tag == 0xA7:
            vs = self.nextvalues(packet)
            if not vs:
                raise DecodeError('Too few elements in encoded record')
            # First item and the remaining items are passed to Record separately
            # (presumably label and fields).
            return self.wrap(Record(vs[0], vs[1:]))
        if tag == 0xA8:
            return self.wrap(tuple(self.nextvalues(packet)))
        if tag == 0xA9:
            return self.wrap(frozenset(self.nextvalues(packet)))
        if tag == 0xAA:
            return self.wrap(ImmutableDict.from_kvs(self.nextvalues(packet)))
        if tag == 0xAB:
            if self.decode_embedded is None:
                raise DecodeError('No decode_embedded function supplied')
            return self.wrap(Embedded(self.decode_embedded(self.next(packet))))
        if tag == 0xBF:
            # The first item is the annotated value; any further items are its annotations.
            if self.include_annotations:
                vs = self.nextvalues(packet)
                if not vs:
                    raise DecodeError('No elements in annotation')
                vs[0].annotations.extend(vs[1:])
                return vs[0]
            else:
                e = self.nextitem(packet)
                if e is None:
                    raise DecodeError('No elements in annotation')
                return e[0]
        raise DecodeError('Invalid tag: ' + hex(tag))

    def nextvalues(self, packet):
        """Decode the run of length-prefixed items filling ``packet`` into a list."""
        vs = []
        while True:
            e = self.nextitem(packet)
            if e is None:
                return vs
            vs.append(e[0])
            packet = e[1]

    def nextitem(self, packet):
        """Decode one length-prefixed item from the front of ``packet``.

        Returns ``(value, rest)``, where ``rest`` is the input following the
        item, or ``None`` when ``packet`` is empty. Raises ``ShortPacket`` when
        the length prefix, or the item it announces, is incomplete.
        """
        if not packet:
            return None
        (count, i) = decode_varint(packet)
        end = i + count
        if end > len(packet):
            # Slicing would silently hand back a truncated item; report it so
            # callers such as StreamDecoder can wait for more data instead.
            raise ShortPacket('Short packet (incomplete item)')
        return (self.next(packet[i:end]), packet[end:])

    def wrap(self, v):
        """Wrap ``v`` in ``Annotated`` when annotations are being kept."""
        return Annotated(v) if self.include_annotations else v

    def try_next(self, packet):
        """Like :meth:`next`, but return ``None`` instead of raising ``ShortPacket``."""
        try:
            return self.next(packet)
        except ShortPacket:
            return None


def decode_varint(packet):
    """Decode a length prefix from the start of ``packet``.

    The length is stored as big-endian groups of 7 bits; the final group's
    byte has its high bit set. Returns ``(length, bytes_consumed)``. Raises
    ``ShortPacket`` if no terminating byte is present.
    """
    count = 0
    for (i, b) in enumerate(packet):
        if b & 0x80:
            return ((count << 7) + (b - 0x80), i + 1)
        count = (count << 7) + b
    raise ShortPacket('Short packet (incomplete length)')


def decode_int(packet):
    """Decode a big-endian two's-complement integer; empty input decodes to 0."""
    return int.from_bytes(packet, 'big', signed=True)


class StreamDecoder(object):
    """Incrementally decode the items of a top-level sequence received in chunks.

    ``initial_packet`` must begin with the sequence tag 0xA8. The rest of it,
    plus anything later passed to :meth:`extend`, is a run of length-prefixed
    items. Iteration yields each complete item. It stops without consuming any
    input when the buffer does not yet hold a complete item.
    """

    def __init__(self, initial_packet, decoder=None):
        self.decoder = decoder or Decoder()
        if not initial_packet:
            raise DecodeError('Empty initial packet in StreamDecoder')
        if initial_packet[0] != 0xA8:
            raise DecodeError('Initial stream packet is not a Sequence')
        self.buffer = memoryview(initial_packet[1:])

    def extend(self, data):
        """Append the bytes-like ``data`` to the pending input."""
        self.buffer = memoryview(bytes(self.buffer) + data)

    def __iter__(self):
        return self

    def __next__(self):
        # nextitem (not next) is required: it reports how much input the item used.
        try:
            e = self.decoder.nextitem(self.buffer)
        except ShortPacket:
            # Incomplete item: leave the buffer untouched so a later extend()
            # can complete it and iteration can resume.
            raise StopIteration
        if e is None:
            raise StopIteration
        (value, self.buffer) = e
        return value


def decode(bs, **kwargs):
    """Decode one value from ``bs``; keyword arguments go to :class:`Decoder`."""
    return Decoder(**kwargs).next(bs)


def decode_with_annotations(bs, **kwargs):
    """Like :func:`decode`, but keep annotations (values come back wrapped in ``Annotated``)."""
    return Decoder(include_annotations=True, **kwargs).next(bs)


class Encoder(BinaryCodec):
    """Encode Python values as binary packets.

    Keyword arguments:
      encode_embedded -- function used by :meth:`encode_embedded`; ``None``
          makes that method raise ``EncodeError``.
      canonicalize -- when true, set elements and dictionary entries are
          emitted sorted by their canonical encodings.
      include_annotations -- stored on the instance and defaults to
          ``not canonicalize``. It is not read elsewhere in this file
          (presumably consulted by annotated values' ``__preserve_encoded__``).
    """

    def __init__(self, *, encode_embedded=lambda x: x, canonicalize=False, include_annotations=None):
        super(Encoder, self).__init__()
        self.buffer = None
        self._encode_embedded = encode_embedded
        self._canonicalize = canonicalize
        if include_annotations is None:
            self.include_annotations = not canonicalize
        else:
            self.include_annotations = include_annotations

    def reset(self):
        self.buffer = None

    def encode_embedded(self, v):
        """Apply the configured embedded-value encoder to ``v``."""
        if self._encode_embedded is None:
            raise EncodeError('No encode_embedded function supplied')
        return self._encode_embedded(v)

    def contents(self):
        return iolist.bytes(self.buffer)

    def lengthprefixed(self, encoded):
        """Prefix the iolist ``encoded`` with its length (see :func:`encode_varint`)."""
        encoded = iolist.counted(encoded)
        return [encode_varint(iolist.len(encoded)), encoded]

    def encodeditem(self, v):
        """Encode ``v`` as a length-prefixed item."""
        return self.lengthprefixed(self.encoded_iolist(v))

    def encodedvalues(self, vs):
        """Encode each value of the iterable ``vs`` as a length-prefixed item."""
        return [self.encodeditem(v) for v in vs]

    def encoded(self, v):
        """Encode ``v`` and return the result as ``bytes``."""
        return iolist.bytes(self.encoded_iolist(v))

    def _canonical_key(self, v):
        # Sort key for canonical ordering: v's canonical encoding, computed with
        # this encoder's embedded-value function so that embedded values are
        # encoded the same way here as in the output.
        return canonicalize(v, encode_embedded=self._encode_embedded)

    def encoded_iolist(self, v):
        """Return an iolist (nested ints / bytes / None / lists) encoding ``v``.

        Values with a ``__preserve_encoded__`` method encode themselves. Other
        values are matched against the built-in types below, and any remaining
        iterable is encoded as a sequence. Anything else goes to
        :meth:`cannot_encode`.
        """
        v = preserve(v)
        if hasattr(v, '__preserve_encoded__'):
            return v.__preserve_encoded__(self)
        if v is False:
            return 0xA0
        if v is True:
            return 0xA1
        if isinstance(v, float):
            return [0xA2, struct.pack('>d', v)]
        if isinstance(v, numbers.Integral):
            # int() also admits integer types lacking bit_length() (e.g. NumPy scalars).
            return [0xA3, encode_int(int(v))]
        if isinstance(v, bytes):
            return [0xA5, v]
        if isinstance(v, (bytearray, memoryview)):
            # Otherwise these would reach the generic-iterable case below and be
            # encoded as a sequence of integers.
            return [0xA5, bytes(v)]
        if isinstance(v, basestring_):
            return [0xA4, v.encode('utf-8'), 0]
        if isinstance(v, (list, tuple)):
            return [0xA8, self.encodedvalues(v)]
        if isinstance(v, (set, frozenset)):
            if self._canonicalize:
                # key= keeps ties from comparing the elements themselves.
                v = sorted(v, key=self._canonical_key)
            return [0xA9, self.encodedvalues(v)]
        if isinstance(v, dict):
            entries = v.items()
            if self._canonicalize:
                entries = sorted(entries, key=lambda kv: self._canonical_key(kv[0]))
            return [0xAA, [[self.encodeditem(k), self.encodeditem(val)] for (k, val) in entries]]
        try:
            items = iter(v)
        except TypeError:
            items = None
        if items is None:
            self.cannot_encode(v)
            return None
        return [0xA8, self.encodedvalues(items)]

    def cannot_encode(self, v):
        raise TypeError('Cannot preserves-encode: ' + repr(v))


def encode_varint(n):
    """Encode a non-negative length as an iolist, the inverse of :func:`decode_varint`.

    Big-endian groups of 7 bits; the final (least-significant) group's byte
    has its high bit set.
    """
    L = (n & 127) | 128
    n = n >> 7
    while n > 0:
        L = [n & 127, L]
        n = n >> 7
    return L


def encode_int(v):
    """Encode an int as minimal big-endian two's complement (an iolist).

    0 encodes as no bytes (``None``).
    """
    if v == 0:
        return None
    if v == -1:
        return 255
    bitcount = (~v if v < 0 else v).bit_length() + 1  # +1 for the sign bit
    bytecount = (bitcount + 7) // 8
    D = None
    for _i in range(bytecount):
        D = [v & 255, D]
        v = v >> 8
    return D


def encode(v, **kwargs):
    """Encode ``v`` to ``bytes``; keyword arguments go to :class:`Encoder`."""
    return Encoder(**kwargs).encoded(v)


def canonicalize(v, **kwargs):
    """Encode ``v`` in canonical form (``Encoder(canonicalize=True)``)."""
    return encode(v, canonicalize=True, **kwargs)