diff --git a/implementations/python/preserves/__init__.py b/implementations/python/preserves/__init__.py index 1e2bd24..220b325 100644 --- a/implementations/python/preserves/__init__.py +++ b/implementations/python/preserves/__init__.py @@ -1,7 +1,6 @@ -from .preserves import Float, Symbol, Record, ImmutableDict +from .repr import Float, Symbol, Record, ImmutableDict +from .repr import Annotated, is_annotated, strip_annotations, annotate -from .preserves import DecodeError, EncodeError, ShortPacket +from .error import DecodeError, EncodeError, ShortPacket -from .preserves import Decoder, Encoder, decode, decode_with_annotations, encode - -from .preserves import Annotated, is_annotated, strip_annotations, annotate +from .binary import Decoder, Encoder, decode, decode_with_annotations, encode diff --git a/implementations/python/preserves/binary.py b/implementations/python/preserves/binary.py new file mode 100644 index 0000000..d263d1f --- /dev/null +++ b/implementations/python/preserves/binary.py @@ -0,0 +1,196 @@ +import numbers +import struct + +from .repr import * +from .error import * +from .compat import basestring_, ord_ + +class Codec(object): pass + +class Decoder(Codec): + def __init__(self, packet=b'', include_annotations=False, decode_embedded=None): + super(Decoder, self).__init__() + self.packet = packet + self.index = 0 + self.include_annotations = include_annotations + self.decode_embedded = decode_embedded + + def extend(self, data): + self.packet = self.packet[self.index:] + data + self.index = 0 + + def nextbyte(self): + if self.index >= len(self.packet): + raise ShortPacket('Short packet') + self.index = self.index + 1 + return ord_(self.packet[self.index - 1]) + + def nextbytes(self, n): + start = self.index + end = start + n + if end > len(self.packet): + raise ShortPacket('Short packet') + self.index = end + return self.packet[start : end] + + def varint(self): + v = self.nextbyte() + if v < 128: + return v + else: + return self.varint() * 128 + (v - 128) + + def peekend(self): + matched = (self.nextbyte() == 0x84) + if not matched: + self.index = self.index - 1 + return matched + + def nextvalues(self): + result = [] + while not self.peekend(): + result.append(self.next()) + return result + + def nextint(self, n): + if n == 0: return 0 + acc = self.nextbyte() + if acc & 0x80: acc = acc - 256 + for _i in range(n - 1): + acc = (acc << 8) | self.nextbyte() + return acc + + def wrap(self, v): + return Annotated(v) if self.include_annotations else v + + def unshift_annotation(self, a, v): + if self.include_annotations: + v.annotations.insert(0, a) + return v + + def next(self): + tag = self.nextbyte() + if tag == 0x80: return self.wrap(False) + if tag == 0x81: return self.wrap(True) + if tag == 0x82: return self.wrap(Float(struct.unpack('>f', self.nextbytes(4))[0])) + if tag == 0x83: return self.wrap(struct.unpack('>d', self.nextbytes(8))[0]) + if tag == 0x84: raise DecodeError('Unexpected end-of-stream marker') + if tag == 0x85: + a = self.next() + v = self.next() + return self.unshift_annotation(a, v) + if tag == 0x86: + if self.decode_embedded is None: + raise DecodeError('No decode_embedded function supplied') + return self.wrap(self.decode_embedded(self.next())) + if tag >= 0x90 and tag <= 0x9f: return self.wrap(tag - (0xa0 if tag > 0x9c else 0x90)) + if tag >= 0xa0 and tag <= 0xaf: return self.wrap(self.nextint(tag - 0xa0 + 1)) + if tag == 0xb0: return self.wrap(self.nextint(self.varint())) + if tag == 0xb1: return self.wrap(self.nextbytes(self.varint()).decode('utf-8')) + if tag == 0xb2: return self.wrap(self.nextbytes(self.varint())) + if tag == 0xb3: return self.wrap(Symbol(self.nextbytes(self.varint()).decode('utf-8'))) + if tag == 0xb4: + vs = self.nextvalues() + if not vs: raise DecodeError('Too few elements in encoded record') + return self.wrap(Record(vs[0], vs[1:])) + if tag == 0xb5: return self.wrap(tuple(self.nextvalues())) + if tag == 0xb6: return self.wrap(frozenset(self.nextvalues())) + if tag == 0xb7: return self.wrap(ImmutableDict.from_kvs(self.nextvalues())) + raise DecodeError('Invalid tag: ' + hex(tag)) + + def try_next(self): + start = self.index + try: + return self.next() + except ShortPacket: + self.index = start + return None + +def decode(bs, **kwargs): + return Decoder(packet=bs, **kwargs).next() + +def decode_with_annotations(bs, **kwargs): + return Decoder(packet=bs, include_annotations=True, **kwargs).next() + +class Encoder(Codec): + def __init__(self, encode_embedded=id): + super(Encoder, self).__init__() + self.buffer = bytearray() + self.encode_embedded = encode_embedded + + def contents(self): + return bytes(self.buffer) + + def varint(self, v): + if v < 128: + self.buffer.append(v) + else: + self.buffer.append((v % 128) + 128) + self.varint(v // 128) + + def encodeint(self, v): + bitcount = (~v if v < 0 else v).bit_length() + 1 + bytecount = (bitcount + 7) // 8 + if bytecount <= 16: + self.buffer.append(0xa0 + bytecount - 1) + else: + self.buffer.append(0xb0) + self.varint(bytecount) + def enc(n,x): + if n > 0: + enc(n-1, x >> 8) + self.buffer.append(x & 255) + enc(bytecount, v) + + def encodevalues(self, tag, items): + self.buffer.append(0xb0 + tag) + for i in items: self.append(i) + self.buffer.append(0x84) + + def encodebytes(self, tag, bs): + self.buffer.append(0xb0 + tag) + self.varint(len(bs)) + self.buffer.extend(bs) + + def append(self, v): + if hasattr(v, '__preserve_on__'): + v.__preserve_on__(self) + elif v is False: + self.buffer.append(0x80) + elif v is True: + self.buffer.append(0x81) + elif isinstance(v, float): + self.buffer.append(0x83) + self.buffer.extend(struct.pack('>d', v)) + elif isinstance(v, numbers.Number): + if v >= -3 and v <= 12: + self.buffer.append(0x90 + (v if v >= 0 else v + 16)) + else: + self.encodeint(v) + elif isinstance(v, bytes): + self.encodebytes(2, v) + elif isinstance(v, basestring_): + self.encodebytes(1, v.encode('utf-8')) + elif isinstance(v, list): + self.encodevalues(5, v) + elif isinstance(v, tuple): + self.encodevalues(5, v) + elif isinstance(v, set): + self.encodevalues(6, v) + elif isinstance(v, frozenset): + self.encodevalues(6, v) + elif isinstance(v, dict): + self.encodevalues(7, list(dict_kvs(v))) + else: + try: + i = iter(v) + except TypeError: + self.buffer.append(0x86) + self.append(self.encode_embedded(v)) + return + self.encodevalues(5, i) + +def encode(v, **kwargs): + e = Encoder(**kwargs) + e.append(v) + return e.contents() diff --git a/implementations/python/preserves/compat.py b/implementations/python/preserves/compat.py new file mode 100644 index 0000000..c4d81b0 --- /dev/null +++ b/implementations/python/preserves/compat.py @@ -0,0 +1,9 @@ +try: + basestring_ = basestring +except NameError: + basestring_ = str + +if isinstance(chr(123), bytes): + ord_ = ord +else: + ord_ = lambda x: x diff --git a/implementations/python/preserves/error.py b/implementations/python/preserves/error.py new file mode 100644 index 0000000..d0aee8a --- /dev/null +++ b/implementations/python/preserves/error.py @@ -0,0 +1,3 @@ +class DecodeError(ValueError): pass +class EncodeError(ValueError): pass +class ShortPacket(DecodeError): pass diff --git a/implementations/python/preserves/preserves.py b/implementations/python/preserves/repr.py similarity index 51% rename from implementations/python/preserves/preserves.py rename to implementations/python/preserves/repr.py index 2b5710e..8356835 100644 --- a/implementations/python/preserves/preserves.py +++ b/implementations/python/preserves/repr.py @@ -1,16 +1,7 @@ import sys -import numbers import struct -try: - basestring -except NameError: - basestring = str - -if isinstance(chr(123), bytes): - _ord = ord -else: - _ord = lambda x: x +from .error import DecodeError class Float(object): def __init__(self, value): @@ -35,10 +26,7 @@ class Float(object): class Symbol(object): def __init__(self, name): - if isinstance(name, Symbol): - self.name = name.name - else: - self.name = name + self.name = name.name if isinstance(name, Symbol) else name def __eq__(self, other): return isinstance(other, Symbol) and self.name == other.name @@ -185,12 +173,6 @@ def dict_kvs(d): yield k yield d[k] -class DecodeError(ValueError): pass -class EncodeError(ValueError): pass -class ShortPacket(DecodeError): pass - -class Codec(object): pass - inf = float('inf') class Annotated(object): @@ -258,191 +240,3 @@ def annotate(v, *anns): for a in anns: v.annotations.append(a) return v - -class Decoder(Codec): - def __init__(self, packet=b'', include_annotations=False, decode_embedded=None): - super(Decoder, self).__init__() - self.packet = packet - self.index = 0 - self.include_annotations = include_annotations - self.decode_embedded = decode_embedded - - def extend(self, data): - self.packet = self.packet[self.index:] + data - self.index = 0 - - def nextbyte(self): - if self.index >= len(self.packet): - raise ShortPacket('Short packet') - self.index = self.index + 1 - return _ord(self.packet[self.index - 1]) - - def nextbytes(self, n): - start = self.index - end = start + n - if end > len(self.packet): - raise ShortPacket('Short packet') - self.index = end - return self.packet[start : end] - - def varint(self): - v = self.nextbyte() - if v < 128: - return v - else: - return self.varint() * 128 + (v - 128) - - def peekend(self): - matched = (self.nextbyte() == 0x84) - if not matched: - self.index = self.index - 1 - return matched - - def nextvalues(self): - result = [] - while not self.peekend(): - result.append(self.next()) - return result - - def nextint(self, n): - if n == 0: return 0 - acc = self.nextbyte() - if acc & 0x80: acc = acc - 256 - for _i in range(n - 1): - acc = (acc << 8) | self.nextbyte() - return acc - - def wrap(self, v): - return Annotated(v) if self.include_annotations else v - - def unshift_annotation(self, a, v): - if self.include_annotations: - v.annotations.insert(0, a) - return v - - def next(self): - tag = self.nextbyte() - if tag == 0x80: return self.wrap(False) - if tag == 0x81: return self.wrap(True) - if tag == 0x82: return self.wrap(Float(struct.unpack('>f', self.nextbytes(4))[0])) - if tag == 0x83: return self.wrap(struct.unpack('>d', self.nextbytes(8))[0]) - if tag == 0x84: raise DecodeError('Unexpected end-of-stream marker') - if tag == 0x85: - a = self.next() - v = self.next() - return self.unshift_annotation(a, v) - if tag == 0x86: - if self.decode_embedded is None: - raise DecodeError('No decode_embedded function supplied') - return self.wrap(self.decode_embedded(self.next())) - if tag >= 0x90 and tag <= 0x9f: return self.wrap(tag - (0xa0 if tag > 0x9c else 0x90)) - if tag >= 0xa0 and tag <= 0xaf: return self.wrap(self.nextint(tag - 0xa0 + 1)) - if tag == 0xb0: return self.wrap(self.nextint(self.varint())) - if tag == 0xb1: return self.wrap(self.nextbytes(self.varint()).decode('utf-8')) - if tag == 0xb2: return self.wrap(self.nextbytes(self.varint())) - if tag == 0xb3: return self.wrap(Symbol(self.nextbytes(self.varint()).decode('utf-8'))) - if tag == 0xb4: - vs = self.nextvalues() - if not vs: raise DecodeError('Too few elements in encoded record') - return self.wrap(Record(vs[0], vs[1:])) - if tag == 0xb5: return self.wrap(tuple(self.nextvalues())) - if tag == 0xb6: return self.wrap(frozenset(self.nextvalues())) - if tag == 0xb7: return self.wrap(ImmutableDict.from_kvs(self.nextvalues())) - raise DecodeError('Invalid tag: ' + hex(tag)) - - def try_next(self): - start = self.index - try: - return self.next() - except ShortPacket: - self.index = start - return None - -def decode(bs, **kwargs): - return Decoder(packet=bs, **kwargs).next() - -def decode_with_annotations(bs, **kwargs): - return Decoder(packet=bs, include_annotations=True, **kwargs).next() - -class Encoder(Codec): - def __init__(self, encode_embedded=id): - super(Encoder, self).__init__() - self.buffer = bytearray() - self.encode_embedded = encode_embedded - - def contents(self): - return bytes(self.buffer) - - def varint(self, v): - if v < 128: - self.buffer.append(v) - else: - self.buffer.append((v % 128) + 128) - self.varint(v // 128) - - def encodeint(self, v): - bitcount = (~v if v < 0 else v).bit_length() + 1 - bytecount = (bitcount + 7) // 8 - if bytecount <= 16: - self.buffer.append(0xa0 + bytecount - 1) - else: - self.buffer.append(0xb0) - self.varint(bytecount) - def enc(n,x): - if n > 0: - enc(n-1, x >> 8) - self.buffer.append(x & 255) - enc(bytecount, v) - - def encodevalues(self, tag, items): - self.buffer.append(0xb0 + tag) - for i in items: self.append(i) - self.buffer.append(0x84) - - def encodebytes(self, tag, bs): - self.buffer.append(0xb0 + tag) - self.varint(len(bs)) - self.buffer.extend(bs) - - def append(self, v): - if hasattr(v, '__preserve_on__'): - v.__preserve_on__(self) - elif v is False: - self.buffer.append(0x80) - elif v is True: - self.buffer.append(0x81) - elif isinstance(v, float): - self.buffer.append(0x83) - self.buffer.extend(struct.pack('>d', v)) - elif isinstance(v, numbers.Number): - if v >= -3 and v <= 12: - self.buffer.append(0x90 + (v if v >= 0 else v + 16)) - else: - self.encodeint(v) - elif isinstance(v, bytes): - self.encodebytes(2, v) - elif isinstance(v, basestring): - self.encodebytes(1, v.encode('utf-8')) - elif isinstance(v, list): - self.encodevalues(5, v) - elif isinstance(v, tuple): - self.encodevalues(5, v) - elif isinstance(v, set): - self.encodevalues(6, v) - elif isinstance(v, frozenset): - self.encodevalues(6, v) - elif isinstance(v, dict): - self.encodevalues(7, list(dict_kvs(v))) - else: - try: - i = iter(v) - except TypeError: - self.buffer.append(0x86) - self.append(self.encode_embedded(v)) - return - self.encodevalues(5, i) - -def encode(v, **kwargs): - e = Encoder(**kwargs) - e.append(v) - return e.contents() diff --git a/implementations/python/preserves/test_preserves.py b/implementations/python/preserves/test_preserves.py index 593a543..152e92b 100644 --- a/implementations/python/preserves/test_preserves.py +++ b/implementations/python/preserves/test_preserves.py @@ -1,6 +1,9 @@ -from .preserves import * -import unittest +import numbers import sys +import unittest + +from . import * +from .compat import basestring_, ord_ if isinstance(chr(123), bytes): def _byte(x): @@ -18,7 +21,7 @@ def _buf(*args): for chunk in args: if isinstance(chunk, bytes): result.append(chunk) - elif isinstance(chunk, basestring): + elif isinstance(chunk, basestring_): result.append(chunk.encode('utf-8')) elif isinstance(chunk, numbers.Number): result.append(_byte(chunk)) @@ -165,9 +168,8 @@ class CodecTests(unittest.TestCase): a2 = A(1) self.assertNotEqual(_e(a1), _e(a2)) self.assertEqual(_e(a1), _e(a1)) - from .preserves import _ord - self.assertEqual(_ord(_e(a1)[0]), 0x86) - self.assertEqual(_ord(_e(a2)[0]), 0x86) + self.assertEqual(ord_(_e(a1)[0]), 0x86) + self.assertEqual(ord_(_e(a2)[0]), 0x86) def test_decode_embedded_absent(self): with self.assertRaises(DecodeError):