# Reconstructed from a whitespace-collapsed `git format-patch` mail:
#
#   From 906f8a01b6c014b0e8adfb4ceb32ff679666c2f7 Mon Sep 17 00:00:00 2001
#   From: Tony Garnock-Jones
#   Date: Tue, 25 Sep 2018 15:53:56 +0100
#   Subject: [PATCH] Python preserves
#
# The patch creates two new files, syndicate/mc/preserve.py and
# syndicate/mc/test_preserve.py.  Both are reproduced below, in that order,
# with their line structure restored.  Because they now share one module,
# the test part no longer needs its original `from preserve import *`.

"""Encoder and decoder for the binary "preserves" value encoding.

Every encoded value starts with a lead byte laid out as
``(major << 6) | (minor << 4) | arg`` (see ``Encoder.leadbyte`` and
``Decoder.peekop``):

* major 0 -- booleans and floats (minor 0), small integers (minor 1),
  stream start (minor 2) and stream end (minor 3) markers;
* major 1 -- length-prefixed binary data: integer (minor 0), string
  (minor 1), bytes (minor 2), symbol (minor 3);
* major 2 -- records; minors 0-2 select a "short form" key, minor 3 means
  the key is carried as the first encoded value;
* major 3 -- collections: sequence (minor 0), set (minor 1),
  dictionary (minor 2).

The code is written to run on both Python 2 and Python 3.
"""

import sys
import numbers
import struct
import unittest

try:
    basestring
except NameError:
    basestring = str

# Indexing a byte string yields a 1-character str on Python 2 but an int on
# Python 3; _ord normalises both to an int.
if isinstance(chr(123), bytes):
    _ord = ord
else:
    _ord = lambda x: x


# ---------------------------------------------------------------------------
# syndicate/mc/preserve.py
# ---------------------------------------------------------------------------

class Float(object):
    """Wrapper marking a number for 4-byte (single precision) encoding.

    Plain Python ``float`` values are encoded as 8-byte doubles instead.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if other.__class__ is self.__class__:
            return self.value == other.value
        # Previously fell off the end and returned None.
        return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ alone makes a class unhashable on Python 3, which
        # broke decoding of sets (frozenset) containing Float values.
        return hash(self.value)

    def __repr__(self):
        return 'Float(' + repr(self.value) + ')'

    def __preserve_on__(self, encoder):
        """Emit lead byte (0, 0, 2) followed by the big-endian 4-byte float."""
        encoder.leadbyte(0, 0, 2)
        encoder.buffer.extend(struct.pack('>f', self.value))


class Symbol(object):
    """A named symbol, distinct from a string carrying the same text."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        return isinstance(other, Symbol) and self.name == other.name

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.name)

    def __repr__(self):
        return '#' + self.name

    def __preserve_on__(self, encoder):
        """Emit the UTF-8 encoded name as binary data with minor 3."""
        bs = self.name.encode('utf-8')
        encoder.header(1, 3, len(bs))
        encoder.buffer.extend(bs)


class Record(object):
    """A key together with an immutable tuple of fields."""

    def __init__(self, key, fields):
        self.key = key
        self.fields = tuple(fields)
        self.__hash = None  # computed lazily by __hash__

    def __eq__(self, other):
        return isinstance(other, Record) and \
            (self.key, self.fields) == (other.key, other.fields)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        if self.__hash is None:
            self.__hash = hash((self.key, self.fields))
        return self.__hash

    def __repr__(self):
        return str(self.key) + '(' + ', '.join((repr(f) for f in self.fields)) + ')'

    def __preserve_on__(self, encoder):
        """Emit the record, using a short form when the key has one.

        When the key is one of ``encoder.shortForms`` its index becomes the
        minor and only the fields are counted; otherwise minor 3 is used and
        the key is emitted as an extra leading value.
        """
        try:
            index = encoder.shortForms.index(self.key)
        except ValueError:
            index = None
        if index is None:
            encoder.header(2, 3, len(self.fields) + 1)
            encoder.append(self.key)
        else:
            encoder.header(2, index, len(self.fields))
        for f in self.fields:
            encoder.append(f)


class ImmutableDict(dict):
    """A hashable ``dict`` whose mutating methods raise ``TypeError``."""

    def __init__(self, *args, **kwargs):
        # The cache attribute is stored under its name-mangled form, so the
        # guard has to test that form; the previous check for the literal
        # name '__hash' never matched and allowed re-initialisation.
        if hasattr(self, '_ImmutableDict__hash'):
            raise TypeError('Immutable')
        super(ImmutableDict, self).__init__(*args, **kwargs)
        self.__hash = None  # computed lazily by __hash__

    def __delitem__(self, key): raise TypeError('Immutable')
    def __setitem__(self, key, val): raise TypeError('Immutable')
    def __ior__(self, other): raise TypeError('Immutable')
    def clear(self): raise TypeError('Immutable')
    def pop(self, k, d=None): raise TypeError('Immutable')
    def popitem(self): raise TypeError('Immutable')
    def setdefault(self, k, d=None): raise TypeError('Immutable')
    def update(self, e, **f): raise TypeError('Immutable')

    def __hash__(self):
        if self.__hash is None:
            # Dict equality ignores ordering, so the hash must too; hashing
            # a frozenset of the items guarantees equal dicts hash equally.
            self.__hash = hash(frozenset(self.items()))
        return self.__hash

    @staticmethod
    def from_kvs(kvs):
        """Build an ImmutableDict from a flat ``k0, v0, k1, v1, ...`` iterable.

        A trailing key without a value is ignored.
        """
        i = iter(kvs)
        # zip over the same iterator twice pairs up consecutive items.
        return ImmutableDict(zip(i, i))


def dict_kvs(d):
    """Yield the keys and values of *d* alternately: k0, v0, k1, v1, ..."""
    for k in d:
        yield k
        yield d[k]


class DecodeError(ValueError):
    """Raised by Decoder for truncated or malformed packets."""


class EncodeError(ValueError):
    """Raised by Encoder for values that cannot be encoded."""


class Codec(object):
    """State shared by Encoder and Decoder: the three short-form record keys."""

    def __init__(self):
        self.shortForms = [Symbol(u'discard'), Symbol(u'capture'), Symbol(u'observe')]

    def set_shortform(self, index, v):
        """Replace short form *index* (0, 1 or 2) with *v*.

        Raises ValueError for any other index.
        """
        if 0 <= index < 3:
            self.shortForms[index] = v
        else:
            raise ValueError('Invalid short form index %r' % (index,))


class Stream(object):
    """Base class for values emitted in streamed (delimited) form.

    Subclasses provide ``major`` and ``minor`` class attributes plus
    ``_emit``.  The output is a start byte (0, 2, arg), the emitted content,
    then an end byte (0, 3, arg), where ``arg = (major << 2) | minor``.
    """

    def __init__(self, iterator):
        self._iterator = iterator

    def __preserve_on__(self, encoder):
        arg = (self.major << 2) | self.minor
        encoder.leadbyte(0, 2, arg)
        self._emit(encoder)
        encoder.leadbyte(0, 3, arg)

    def _emit(self, encoder):
        raise NotImplementedError('Should be implemented in subclasses')


class ValueStream(Stream):
    """Stream whose content is a series of arbitrary encoded values."""
    major = 3

    def _emit(self, encoder):
        for v in self._iterator:
            encoder.append(v)


class SequenceStream(ValueStream):
    minor = 0


class SetStream(ValueStream):
    minor = 1


class DictStream(ValueStream):
    """Stream of (key, value) pairs, emitted as alternating keys and values."""
    minor = 2

    def _emit(self, encoder):
        for (k, v) in self._iterator:
            encoder.append(k)
            encoder.append(v)


class BinaryStream(Stream):
    """Stream of ``bytes`` chunks that the decoder joins into one bytes value."""
    major = 1
    minor = 2

    def _emit(self, encoder):
        for chunk in self._iterator:
            if not isinstance(chunk, bytes):
                raise EncodeError('Illegal chunk in BinaryStream %r' % (chunk,))
            encoder.append(chunk)


class StringStream(BinaryStream):
    """Like BinaryStream, but the joined chunks decode to a string."""
    minor = 1


class SymbolStream(BinaryStream):
    """Like BinaryStream, but the joined chunks decode to a Symbol."""
    minor = 3


class Decoder(Codec):
    """Decodes values from a byte string; ``next()`` returns one value."""

    def __init__(self, packet):
        super(Decoder, self).__init__()
        self.packet = packet
        self.index = 0

    def peekbyte(self):
        """Return the next byte as an int without consuming it."""
        if self.index < len(self.packet):
            return _ord(self.packet[self.index])
        else:
            raise DecodeError('Short packet')

    def advance(self, count=1):
        """Skip *count* bytes and return the index held before skipping."""
        start = self.index
        self.index = self.index + count
        return start

    def nextbyte(self):
        val = self.peekbyte()
        self.advance()
        return val

    def wirelength(self, arg):
        """Return the length named by a lead byte's *arg*.

        Values below 15 are the length itself; 15 means a varint follows.
        """
        if arg < 15:
            return arg
        return self.varint()

    def varint(self):
        """Read a little-endian base-128 integer.

        Each byte carries 7 bits; a set high bit means more bytes follow.
        """
        result = 0
        shift = 0
        while True:
            v = self.nextbyte()
            result |= (v & 127) << shift
            if v < 128:
                return result
            shift += 7

    def nextbytes(self, n):
        """Consume and return exactly *n* bytes.

        Raises DecodeError when fewer than *n* bytes remain; previously the
        slice was silently truncated.
        """
        if self.index + n > len(self.packet):
            raise DecodeError('Short packet')
        start = self.advance(n)
        return self.packet[start : self.index]

    def nextvalues(self, n):
        """Decode and return the next *n* values as a list."""
        result = []
        for i in range(n):
            result.append(self.next())
        return result

    def peekop(self):
        """Split the next byte into (major, minor, arg) without consuming it."""
        b = self.peekbyte()
        major = b >> 6
        minor = (b >> 4) & 3
        arg = b & 15
        return (major, minor, arg)

    def nextop(self):
        op = self.peekop()
        self.advance()
        return op

    def peekend(self, arg):
        """Tell whether the next byte is the end marker for stream *arg*."""
        return self.peekop() == (0, 3, arg)

    def binarystream(self, arg, minor):
        """Decode streamed binary data: join the chunks, then decode as *minor*."""
        result = []
        while not self.peekend(arg):
            chunk = self.next()
            if isinstance(chunk, bytes):
                result.append(chunk)
            else:
                raise DecodeError('Unexpected non-binary chunk')
        # Consume the end marker so that whatever follows the stream (for
        # example the next element of an enclosing collection) decodes.
        self.advance()
        return self.decodebinary(minor, b''.join(result))

    def valuestream(self, arg, minor, decoder):
        """Decode streamed values and hand them to *decoder(minor, values)*."""
        result = []
        while not self.peekend(arg):
            result.append(self.next())
        # Consume the end marker; see binarystream.
        self.advance()
        return decoder(minor, result)

    def decodeint(self, bs):
        """Decode a big-endian two's-complement integer; empty input is 0."""
        if len(bs) == 0: return 0
        acc = _ord(bs[0])
        if acc & 0x80: acc = acc - 256  # sign-extend the leading byte
        for b in bs[1:]:
            acc = (acc << 8) | _ord(b)
        return acc

    def decodebinary(self, minor, bs):
        if minor == 0: return self.decodeint(bs)
        if minor == 1: return bs.decode('utf-8')
        if minor == 2: return bs
        if minor == 3: return Symbol(bs.decode('utf-8'))

    def decoderecord(self, minor, vs):
        if minor == 3:
            if not vs: raise DecodeError('Too few elements in encoded record')
            return Record(vs[0], vs[1:])
        else:
            return Record(self.shortForms[minor], vs)

    def decodecollection(self, minor, vs):
        if minor == 0: return tuple(vs)
        if minor == 1: return frozenset(vs)
        if minor == 2:
            # An odd count means a key with no value; report it instead of
            # silently dropping the key.
            if len(vs) % 2:
                raise DecodeError('Missing dictionary value')
            return ImmutableDict.from_kvs(vs)
        if minor == 3: raise DecodeError('Invalid collection type')

    def next(self):
        """Decode and return the next value in the packet.

        Raises DecodeError on truncated or malformed input.
        """
        (major, minor, arg) = self.nextop()
        if major == 0:
            if minor == 0:
                if arg == 0: return False
                if arg == 1: return True
                if arg == 2: return Float(struct.unpack('>f', self.nextbytes(4))[0])
                if arg == 3: return struct.unpack('>d', self.nextbytes(8))[0]
                raise DecodeError('Invalid format A encoding')
            elif minor == 1:
                # Small integers: 0..12 stand for themselves, 13..15 for -3..-1.
                return arg - 16 if arg > 12 else arg
            elif minor == 2:
                # Stream start: arg packs the streamed value's major and minor.
                t = arg >> 2
                n = arg & 3
                if t == 0: raise DecodeError('Invalid format C start byte')
                if t == 1: return self.binarystream(arg, n)
                if t == 2: return self.valuestream(arg, n, self.decoderecord)
                if t == 3: return self.valuestream(arg, n, self.decodecollection)
            else: # minor == 3
                raise DecodeError('Unexpected format C end byte')
        elif major == 1:
            return self.decodebinary(minor, self.nextbytes(self.wirelength(arg)))
        elif major == 2:
            return self.decoderecord(minor, self.nextvalues(self.wirelength(arg)))
        else: # major == 3
            return self.decodecollection(minor, self.nextvalues(self.wirelength(arg)))


class Encoder(Codec):
    """Accumulates encoded values in ``buffer``; ``contents()`` returns them."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.buffer = bytearray()

    def contents(self):
        return bytes(self.buffer)

    def varint(self, v):
        """Append *v* as a little-endian base-128 integer (see Decoder.varint)."""
        while v >= 128:
            self.buffer.append((v % 128) + 128)
            v = v // 128
        self.buffer.append(v)

    def leadbyte(self, major, minor, arg):
        self.buffer.append(((major & 3) << 6) | ((minor & 3) << 4) | (arg & 15))

    def header(self, major, minor, wirelength):
        """Append a lead byte carrying *wirelength*.

        Lengths below 15 fit in the lead byte; otherwise arg is 15 and the
        length follows as a varint.
        """
        if wirelength < 15:
            self.leadbyte(major, minor, wirelength)
        else:
            self.leadbyte(major, minor, 15)
            self.varint(wirelength)

    def encodeint(self, v):
        """Append *v* as a minimal big-endian two's-complement integer."""
        # One bit more than the magnitude needs, to leave room for the sign.
        bitcount = (~v if v < 0 else v).bit_length() + 1
        bytecount = (bitcount + 7) // 8
        self.header(1, 0, bytecount)
        for shift in range(8 * (bytecount - 1), -1, -8):
            self.buffer.append((v >> shift) & 255)

    def encodecollection(self, minor, items):
        self.header(3, minor, len(items))
        for i in items: self.append(i)

    def encodestream(self, t, n, items):
        """Append *items* in streamed form, as Stream.__preserve_on__ does.

        *t* and *n* are the major and minor of the streamed value.  This
        method was called by append() but was missing from the class.
        """
        arg = (t << 2) | n
        self.leadbyte(0, 2, arg)
        for i in items: self.append(i)
        self.leadbyte(0, 3, arg)

    def append(self, v):
        """Encode *v* and append it to the buffer.

        Objects with a ``__preserve_on__`` method encode themselves.  Any
        other iterable not matched by an explicit case is emitted as a
        streamed sequence.  Raises EncodeError for anything else.
        """
        if hasattr(v, '__preserve_on__'):
            v.__preserve_on__(self)
        elif v is False:
            self.leadbyte(0, 0, 0)
        elif v is True:
            self.leadbyte(0, 0, 1)
        elif isinstance(v, float):
            self.leadbyte(0, 0, 3)
            self.buffer.extend(struct.pack('>d', v))
        elif isinstance(v, numbers.Integral):
            # Only integers belong here; other numbers (complex, Fraction,
            # ...) fall through to the EncodeError below rather than failing
            # inside the integer encoder.
            if v >= -3 and v <= 12:
                self.leadbyte(0, 1, v if v >= 0 else v + 16)
            else:
                self.encodeint(v)
        elif isinstance(v, bytes):
            self.header(1, 2, len(v))
            self.buffer.extend(v)
        elif isinstance(v, basestring):
            bs = v.encode('utf-8')
            self.header(1, 1, len(bs))
            self.buffer.extend(bs)
        elif isinstance(v, list):
            self.encodecollection(0, v)
        elif isinstance(v, tuple):
            self.encodecollection(0, v)
        elif isinstance(v, set):
            self.encodecollection(1, v)
        elif isinstance(v, frozenset):
            self.encodecollection(1, v)
        elif isinstance(v, dict):
            self.encodecollection(2, list(dict_kvs(v)))
        else:
            try:
                i = iter(v)
            except TypeError:
                raise EncodeError('Cannot encode %r' % (v,))
            self.encodestream(3, 0, i)


# ---------------------------------------------------------------------------
# syndicate/mc/test_preserve.py
# ---------------------------------------------------------------------------

if isinstance(chr(123), bytes):
    def _byte(x):
        return chr(x)
    def _hex(x):
        return x.encode('hex')
else:
    def _byte(x):
        return bytes([x])
    def _hex(x):
        return x.hex()


def _buf(*args):
    """Concatenate bytes, text (UTF-8 encoded) and ints (single bytes)."""
    result = []
    for chunk in args:
        if isinstance(chunk, bytes):
            result.append(chunk)
        elif isinstance(chunk, basestring):
            result.append(chunk.encode('utf-8'))
        elif isinstance(chunk, numbers.Number):
            result.append(_byte(chunk))
        else:
            raise Exception('Invalid chunk in _buf %r' % (chunk,))
    result = b''.join(result)
    return result


def _varint(v):
    e = Encoder()
    e.varint(v)
    return e.contents()


def _d(bs):
    d = Decoder(bs)
    return d.next()


def _e(v):
    e = Encoder()
    e.append(v)
    return e.contents()


def _R(k, *args):
    return Record(Symbol(k), args)


class CodecTests(unittest.TestCase):
    def _roundtrip(self, forward, expected, back=None, nondeterministic=False):
        """Check encode/decode of *forward* against the *expected* bytes.

        *back* is the value decoding should yield (defaults to *forward*).
        With *nondeterministic* set, the encoded bytes are not compared.
        """
        if back is None: back = forward
        self.assertEqual(_d(_e(forward)), back)
        self.assertEqual(_d(_e(back)), back)
        self.assertEqual(_d(expected), back)
        if not nondeterministic:
            actual = _e(forward)
            self.assertEqual(actual, expected, '%s != %s' % (_hex(actual), _hex(expected)))

    def test_decode_varint(self):
        with self.assertRaises(DecodeError):
            Decoder(_buf()).varint()
        self.assertEqual(Decoder(_buf(0)).varint(), 0)
        self.assertEqual(Decoder(_buf(10)).varint(), 10)
        self.assertEqual(Decoder(_buf(100)).varint(), 100)
        self.assertEqual(Decoder(_buf(200, 1)).varint(), 200)
        self.assertEqual(Decoder(_buf(0b10101100, 0b00000010)).varint(), 300)
        self.assertEqual(Decoder(_buf(128, 148, 235, 220, 3)).varint(), 1000000000)

    def test_encode_varint(self):
        self.assertEqual(_varint(0), _buf(0))
        self.assertEqual(_varint(10), _buf(10))
        self.assertEqual(_varint(100), _buf(100))
        self.assertEqual(_varint(200), _buf(200, 1))
        self.assertEqual(_varint(300), _buf(0b10101100, 0b00000010))
        self.assertEqual(_varint(1000000000), _buf(128, 148, 235, 220, 3))

    def test_shorts(self):
        self._roundtrip(_R('capture', _R('discard')), _buf(0x91, 0x80))
        self._roundtrip(_R('observe', _R('speak', _R('discard'), _R('capture', _R('discard')))),
                        _buf(0xA1, 0xB3, 0x75, "speak", 0x80, 0x91, 0x80))

    def test_simple_seq(self):
        self._roundtrip([1,2,3,4], _buf(0xC4, 0x11, 0x12, 0x13, 0x14), back=(1,2,3,4))
        self._roundtrip(SequenceStream([1,2,3,4]), _buf(0x2C, 0x11, 0x12, 0x13, 0x14, 0x3C),
                        back=(1,2,3,4))
        self._roundtrip((-2,-1,0,1), _buf(0xC4, 0x1E, 0x1F, 0x10, 0x11))

    def test_str(self):
        self._roundtrip(u'hello', _buf(0x55, 'hello'))
        self._roundtrip(StringStream([b'he', b'llo']), _buf(0x25, 0x62, 'he', 0x63, 'llo', 0x35),
                        back=u'hello')
        self._roundtrip(StringStream([b'he', b'll', b'', b'', b'o']),
                        _buf(0x25, 0x62, 'he', 0x62, 'll', 0x60, 0x60, 0x61, 'o', 0x35),
                        back=u'hello')
        self._roundtrip(BinaryStream([b'he', b'll', b'', b'', b'o']),
                        _buf(0x26, 0x62, 'he', 0x62, 'll', 0x60, 0x60, 0x61, 'o', 0x36),
                        back=b'hello')
        self._roundtrip(SymbolStream([b'he', b'll', b'', b'', b'o']),
                        _buf(0x27, 0x62, 'he', 0x62, 'll', 0x60, 0x60, 0x61, 'o', 0x37),
                        back=Symbol(u'hello'))

    def test_mixed1(self):
        self._roundtrip((u'hello', Symbol(u'there'), b'world', (), set(), True, False),
                        _buf(0xc7, 0x55, 'hello', 0x75, 'there', 0x65, 'world', 0xc0, 0xd0, 1, 0))

    def test_signedinteger(self):
        self._roundtrip(-257, _buf(0x42, 0xFE, 0xFF))
        self._roundtrip(-256, _buf(0x42, 0xFF, 0x00))
        self._roundtrip(-255, _buf(0x42, 0xFF, 0x01))
        self._roundtrip(-254, _buf(0x42, 0xFF, 0x02))
        self._roundtrip(-129, _buf(0x42, 0xFF, 0x7F))
        self._roundtrip(-128, _buf(0x41, 0x80))
        self._roundtrip(-127, _buf(0x41, 0x81))
        self._roundtrip(-4, _buf(0x41, 0xFC))
        self._roundtrip(-3, _buf(0x1D))
        self._roundtrip(-2, _buf(0x1E))
        self._roundtrip(-1, _buf(0x1F))
        self._roundtrip(0, _buf(0x10))
        self._roundtrip(1, _buf(0x11))
        self._roundtrip(12, _buf(0x1C))
        self._roundtrip(13, _buf(0x41, 0x0D))
        self._roundtrip(127, _buf(0x41, 0x7F))
        self._roundtrip(128, _buf(0x42, 0x00, 0x80))
        self._roundtrip(255, _buf(0x42, 0x00, 0xFF))
        self._roundtrip(256, _buf(0x42, 0x01, 0x00))
        self._roundtrip(32767, _buf(0x42, 0x7F, 0xFF))
        self._roundtrip(32768, _buf(0x43, 0x00, 0x80, 0x00))
        self._roundtrip(65535, _buf(0x43, 0x00, 0xFF, 0xFF))
        self._roundtrip(65536, _buf(0x43, 0x01, 0x00, 0x00))
        self._roundtrip(131072, _buf(0x43, 0x02, 0x00, 0x00))

    def test_floats(self):
        self._roundtrip(Float(1.0), _buf(2, 0x3f, 0x80, 0, 0))
        self._roundtrip(1.0, _buf(3, 0x3f, 0xf0, 0, 0, 0, 0, 0, 0))
        self._roundtrip(-1.202e300, _buf(3, 0xfe, 0x3c, 0xb7, 0xb7, 0x59, 0xbf, 0x04, 0x26))

    def test_badchunks(self):
        self.assertEqual(_d(_buf(0x25, 0x61, 'a', 0x35)), u'a')
        self.assertEqual(_d(_buf(0x26, 0x61, 'a', 0x36)), b'a')
        self.assertEqual(_d(_buf(0x27, 0x61, 'a', 0x37)), Symbol(u'a'))
        for a in [0x25, 0x26, 0x27]:
            for b in [0x51, 0x71]:
                with self.assertRaises(DecodeError, msg='Unexpected non-binary chunk'):
                    _d(_buf(a, b, 'a', 0x10+a))

    def test_person(self):
        self._roundtrip(Record((Symbol(u'titled'), Symbol(u'person'), 2, Symbol(u'thing'), 1),
                               [
                                   101,
                                   u'Blackwell',
                                   _R(u'date', 1821, 2, 3),
                                   u'Dr'
                               ]),
                        _buf(0xB5, 0xC5, 0x76, 0x74, 0x69, 0x74, 0x6C, 0x65,
                             0x64, 0x76, 0x70, 0x65, 0x72, 0x73, 0x6F, 0x6E,
                             0x12, 0x75, 0x74, 0x68, 0x69, 0x6E, 0x67, 0x11,
                             0x41, 0x65, 0x59, 0x42, 0x6C, 0x61, 0x63, 0x6B,
                             0x77, 0x65, 0x6C, 0x6C, 0xB4, 0x74, 0x64, 0x61,
                             0x74, 0x65, 0x42, 0x07, 0x1D, 0x12, 0x13, 0x52,
                             0x44, 0x72))

    def test_dict(self):
        self._roundtrip({ Symbol(u'a'): 1,
                          u'b': True,
                          (1, 2, 3): b'c',
                          ImmutableDict({ Symbol(u'first-name'): u'Elizabeth', }):
                          { Symbol(u'surname'): u'Blackwell' } },
                        _buf(0xE8,
                             0x71, "a", 0x11,
                             0x51, "b", 0x01,
                             0xC3, 0x11, 0x12, 0x13, 0x61, "c",
                             0xE2, 0x7A, "first-name", 0x59, "Elizabeth",
                             0xE2, 0x77, "surname", 0x59, "Blackwell"),
                        nondeterministic = True)

    def test_nested_stream(self):
        # A stream inside a collection: the end marker must be consumed.
        self.assertEqual(_d(_buf(0xC2, 0x2C, 0x11, 0x3C, 0x12)), ((1,), 2))
        self.assertEqual(_d(_buf(0xC2, 0x25, 0x61, 'a', 0x35, 0x12)), (u'a', 2))

    def test_generic_iterable(self):
        self.assertEqual(_e(iter([1, 2])), _buf(0x2C, 0x11, 0x12, 0x3C))
        with self.assertRaises(EncodeError):
            _e(object())

    def test_short_packet(self):
        for packet in [_buf(0x55, 'he'), _buf(2, 0x3f), _buf(0xC2, 0x11)]:
            with self.assertRaises(DecodeError):
                _d(packet)