Python preserves

This commit is contained in:
Tony Garnock-Jones 2018-09-25 15:53:56 +01:00
parent b6a3c480b3
commit 906f8a01b6
2 changed files with 556 additions and 0 deletions

385
syndicate/mc/preserve.py Normal file
View File

@ -0,0 +1,385 @@
import sys
import numbers
import struct
try:
basestring
except NameError:
basestring = str
if isinstance(chr(123), bytes):
_ord = ord
else:
_ord = lambda x: x
class Float(object):
def __init__(self, value):
self.value = value
def __eq__(self, other):
if other.__class__ is self.__class__:
return self.value == other.value
def __repr__(self):
return 'Float(' + repr(self.value) + ')'
def __preserve_on__(self, encoder):
encoder.leadbyte(0, 0, 2)
encoder.buffer.extend(struct.pack('>f', self.value))
class Symbol(object):
def __init__(self, name):
self.name = name
def __eq__(self, other):
return isinstance(other, Symbol) and self.name == other.name
def __hash__(self):
return hash(self.name)
def __repr__(self):
return '#' + self.name
def __preserve_on__(self, encoder):
bs = self.name.encode('utf-8')
encoder.header(1, 3, len(bs))
encoder.buffer.extend(bs)
class Record(object):
def __init__(self, key, fields):
self.key = key
self.fields = tuple(fields)
self.__hash = None
def __eq__(self, other):
return isinstance(other, Record) and (self.key, self.fields) == (other.key, other.fields)
def __hash__(self):
if self.__hash is None:
self.__hash = hash((self.key, self.fields))
return self.__hash
def __repr__(self):
return str(self.key) + '(' + ', '.join((repr(f) for f in self.fields)) + ')'
def __preserve_on__(self, encoder):
try:
index = encoder.shortForms.index(self.key)
except ValueError:
index = None
if index is None:
encoder.header(2, 3, len(self.fields) + 1)
encoder.append(self.key)
else:
encoder.header(2, index, len(self.fields))
for f in self.fields:
encoder.append(f)
# Blub blub blub
class ImmutableDict(dict):
def __init__(self, *args, **kwargs):
if hasattr(self, '__hash'): raise TypeError('Immutable')
super(ImmutableDict, self).__init__(*args, **kwargs)
self.__hash = None
def __delitem__(self, key): raise TypeError('Immutable')
def __setitem__(self, key, val): raise TypeError('Immutable')
def clear(self): raise TypeError('Immutable')
def pop(self, k, d=None): raise TypeError('Immutable')
def popitem(self): raise TypeError('Immutable')
def setdefault(self, k, d=None): raise TypeError('Immutable')
def update(self, e, **f): raise TypeError('Immutable')
def __hash__(self):
if self.__hash is None:
h = 0
for k in self:
h = ((h << 5) ^ (hash(k) << 2) ^ hash(self[k])) & sys.maxsize
self.__hash = h
return self.__hash
@staticmethod
def from_kvs(kvs):
i = iter(kvs)
result = ImmutableDict()
result_proxy = super(ImmutableDict, result)
try:
while True:
k = next(i)
v = next(i)
result_proxy.__setitem__(k, v)
except StopIteration:
pass
return result
def dict_kvs(d):
for k in d:
yield k
yield d[k]
class DecodeError(ValueError): pass
class EncodeError(ValueError): pass
class Codec(object):
def __init__(self):
self.shortForms = [Symbol(u'discard'), Symbol(u'capture'), Symbol(u'observe')]
def set_shortform(self, index, v):
if index >= 0 and index < 3:
self.shortForms[index] = v
else:
raise ValueError('Invalid short form index %r' % (index,))
class Stream(object):
def __init__(self, iterator):
self._iterator = iterator
def __preserve_on__(self, encoder):
arg = (self.major << 2) | self.minor
encoder.leadbyte(0, 2, arg)
self._emit(encoder)
encoder.leadbyte(0, 3, arg)
def _emit(self, encoder):
raise NotImplementedError('Should be implemented in subclasses')
class ValueStream(Stream):
major = 3
def _emit(self, encoder):
for v in self._iterator:
encoder.append(v)
class SequenceStream(ValueStream):
minor = 0
class SetStream(ValueStream):
minor = 1
class DictStream(ValueStream):
minor = 2
def _emit(self, encoder):
for (k, v) in self._iterator:
encoder.append(k)
encoder.append(v)
class BinaryStream(Stream):
major = 1
minor = 2
def _emit(self, encoder):
for chunk in self._iterator:
if not isinstance(chunk, bytes):
raise EncodeError('Illegal chunk in BinaryStream %r' % (chunk,))
encoder.append(chunk)
class StringStream(BinaryStream):
minor = 1
class SymbolStream(BinaryStream):
minor = 3
class Decoder(Codec):
def __init__(self, packet):
super(Decoder, self).__init__()
self.packet = packet
self.index = 0
def peekbyte(self):
if self.index < len(self.packet):
return _ord(self.packet[self.index])
else:
raise DecodeError('Short packet')
def advance(self, count=1):
start = self.index
self.index = self.index + count
return start
def nextbyte(self):
val = self.peekbyte()
self.advance()
return val
def wirelength(self, arg):
if arg < 15:
return arg
return self.varint()
def varint(self):
v = self.nextbyte()
if v < 128:
return v
else:
return self.varint() * 128 + (v - 128)
def nextbytes(self, n):
start = self.advance(n)
return self.packet[start : self.index]
def nextvalues(self, n):
result = []
for i in range(n):
result.append(self.next())
return result
def peekop(self):
b = self.peekbyte()
major = b >> 6
minor = (b >> 4) & 3
arg = b & 15
return (major, minor, arg)
def nextop(self):
op = self.peekop()
self.advance()
return op
def peekend(self, arg):
return self.peekop() == (0, 3, arg)
def binarystream(self, arg, minor):
result = []
while not self.peekend(arg):
chunk = self.next()
if isinstance(chunk, bytes):
result.append(chunk)
else:
raise DecodeError('Unexpected non-binary chunk')
return self.decodebinary(minor, b''.join(result))
def valuestream(self, arg, minor, decoder):
result = []
while not self.peekend(arg):
result.append(self.next())
return decoder(minor, result)
def decodeint(self, bs):
if len(bs) == 0: return 0
acc = _ord(bs[0])
if acc & 0x80: acc = acc - 256
for b in bs[1:]:
acc = (acc << 8) | _ord(b)
return acc
def decodebinary(self, minor, bs):
if minor == 0: return self.decodeint(bs)
if minor == 1: return bs.decode('utf-8')
if minor == 2: return bs
if minor == 3: return Symbol(bs.decode('utf-8'))
def decoderecord(self, minor, vs):
if minor == 3:
if not vs: raise DecodeError('Too few elements in encoded record')
return Record(vs[0], vs[1:])
else:
return Record(self.shortForms[minor], vs)
def decodecollection(self, minor, vs):
if minor == 0: return tuple(vs)
if minor == 1: return frozenset(vs)
if minor == 2: return ImmutableDict.from_kvs(vs)
if minor == 3: raise DecodeError('Invalid collection type')
def next(self):
(major, minor, arg) = self.nextop()
if major == 0:
if minor == 0:
if arg == 0: return False
if arg == 1: return True
if arg == 2: return Float(struct.unpack('>f', self.nextbytes(4))[0])
if arg == 3: return struct.unpack('>d', self.nextbytes(8))[0]
raise DecodeError('Invalid format A encoding')
elif minor == 1:
return arg - 16 if arg > 12 else arg
elif minor == 2:
t = arg >> 2
n = arg & 3
if t == 0: raise DecodeError('Invalid format C start byte')
if t == 1: return self.binarystream(arg, n)
if t == 2: return self.valuestream(arg, n, self.decoderecord)
if t == 3: return self.valuestream(arg, n, self.decodecollection)
else: # minor == 3
raise DecodeError('Unexpected format C end byte')
elif major == 1:
return self.decodebinary(minor, self.nextbytes(self.wirelength(arg)))
elif major == 2:
return self.decoderecord(minor, self.nextvalues(self.wirelength(arg)))
else: # major == 3
return self.decodecollection(minor, self.nextvalues(self.wirelength(arg)))
class Encoder(Codec):
def __init__(self):
super(Encoder, self).__init__()
self.buffer = bytearray()
def contents(self):
return bytes(self.buffer)
def varint(self, v):
if v < 128:
self.buffer.append(v)
else:
self.buffer.append((v % 128) + 128)
self.varint(v // 128)
def leadbyte(self, major, minor, arg):
self.buffer.append(((major & 3) << 6) | ((minor & 3) << 4) | (arg & 15))
def header(self, major, minor, wirelength):
if wirelength < 15:
self.leadbyte(major, minor, wirelength)
else:
self.leadbyte(major, minor, 15)
self.varint(wirelength)
def encodeint(self, v):
bitcount = (~v if v < 0 else v).bit_length() + 1
bytecount = (bitcount + 7) // 8
self.header(1, 0, bytecount)
def enc(n,x):
if n > 0:
enc(n-1, x >> 8)
self.buffer.append(x & 255)
enc(bytecount, v)
def encodecollection(self, minor, items):
self.header(3, minor, len(items))
for i in items: self.append(i)
def append(self, v):
if hasattr(v, '__preserve_on__'):
v.__preserve_on__(self)
elif v is False:
self.leadbyte(0, 0, 0)
elif v is True:
self.leadbyte(0, 0, 1)
elif isinstance(v, float):
self.leadbyte(0, 0, 3)
self.buffer.extend(struct.pack('>d', v))
elif isinstance(v, numbers.Number):
if v >= -3 and v <= 12:
self.leadbyte(0, 1, v if v >= 0 else v + 16)
else:
self.encodeint(v)
elif isinstance(v, bytes):
self.header(1, 2, len(v))
self.buffer.extend(v)
elif isinstance(v, basestring):
bs = v.encode('utf-8')
self.header(1, 1, len(bs))
self.buffer.extend(bs)
elif isinstance(v, list):
self.encodecollection(0, v)
elif isinstance(v, tuple):
self.encodecollection(0, v)
elif isinstance(v, set):
self.encodecollection(1, v)
elif isinstance(v, frozenset):
self.encodecollection(1, v)
elif isinstance(v, dict):
self.encodecollection(2, list(dict_kvs(v)))
else:
try:
i = iter(v)
except TypeError:
raise EncodeError('Cannot encode %r' % (v,))
self.encodestream(3, 0, i)

View File

@ -0,0 +1,171 @@
from preserve import *
import unittest
if isinstance(chr(123), bytes):
def _byte(x):
return chr(x)
def _hex(x):
return x.encode('hex')
else:
def _byte(x):
return bytes([x])
def _hex(x):
return x.hex()
def _buf(*args):
result = []
for chunk in args:
if isinstance(chunk, bytes):
result.append(chunk)
elif isinstance(chunk, basestring):
result.append(chunk.encode('utf-8'))
elif isinstance(chunk, numbers.Number):
result.append(_byte(chunk))
else:
raise Exception('Invalid chunk in _buf %r' % (chunk,))
result = b''.join(result)
return result
def _varint(v):
e = Encoder()
e.varint(v)
return e.contents()
def _d(bs):
d = Decoder(bs)
return d.next()
def _e(v):
e = Encoder()
e.append(v)
return e.contents()
def _R(k, *args):
return Record(Symbol(k), args)
class CodecTests(unittest.TestCase):
def _roundtrip(self, forward, expected, back=None, nondeterministic=False):
if back is None: back = forward
self.assertEqual(_d(_e(forward)), back)
self.assertEqual(_d(_e(back)), back)
self.assertEqual(_d(expected), back)
if not nondeterministic:
actual = _e(forward)
self.assertEqual(actual, expected, '%s != %s' % (_hex(actual), _hex(expected)))
def test_decode_varint(self):
with self.assertRaises(DecodeError):
Decoder(_buf()).varint()
self.assertEqual(Decoder(_buf(0)).varint(), 0)
self.assertEqual(Decoder(_buf(10)).varint(), 10)
self.assertEqual(Decoder(_buf(100)).varint(), 100)
self.assertEqual(Decoder(_buf(200, 1)).varint(), 200)
self.assertEqual(Decoder(_buf(0b10101100, 0b00000010)).varint(), 300)
self.assertEqual(Decoder(_buf(128, 148, 235, 220, 3)).varint(), 1000000000)
def test_encode_varint(self):
self.assertEqual(_varint(0), _buf(0))
self.assertEqual(_varint(10), _buf(10))
self.assertEqual(_varint(100), _buf(100))
self.assertEqual(_varint(200), _buf(200, 1))
self.assertEqual(_varint(300), _buf(0b10101100, 0b00000010))
self.assertEqual(_varint(1000000000), _buf(128, 148, 235, 220, 3))
def test_shorts(self):
self._roundtrip(_R('capture', _R('discard')), _buf(0x91, 0x80))
self._roundtrip(_R('observe', _R('speak', _R('discard'), _R('capture', _R('discard')))),
_buf(0xA1, 0xB3, 0x75, "speak", 0x80, 0x91, 0x80))
def test_simple_seq(self):
self._roundtrip([1,2,3,4], _buf(0xC4, 0x11, 0x12, 0x13, 0x14), back=(1,2,3,4))
self._roundtrip(SequenceStream([1,2,3,4]), _buf(0x2C, 0x11, 0x12, 0x13, 0x14, 0x3C),
back=(1,2,3,4))
self._roundtrip((-2,-1,0,1), _buf(0xC4, 0x1E, 0x1F, 0x10, 0x11))
def test_str(self):
self._roundtrip(u'hello', _buf(0x55, 'hello'))
self._roundtrip(StringStream([b'he', b'llo']), _buf(0x25, 0x62, 'he', 0x63, 'llo', 0x35),
back=u'hello')
self._roundtrip(StringStream([b'he', b'll', b'', b'', b'o']),
_buf(0x25, 0x62, 'he', 0x62, 'll', 0x60, 0x60, 0x61, 'o', 0x35),
back=u'hello')
self._roundtrip(BinaryStream([b'he', b'll', b'', b'', b'o']),
_buf(0x26, 0x62, 'he', 0x62, 'll', 0x60, 0x60, 0x61, 'o', 0x36),
back=b'hello')
self._roundtrip(SymbolStream([b'he', b'll', b'', b'', b'o']),
_buf(0x27, 0x62, 'he', 0x62, 'll', 0x60, 0x60, 0x61, 'o', 0x37),
back=Symbol(u'hello'))
def test_mixed1(self):
self._roundtrip((u'hello', Symbol(u'there'), b'world', (), set(), True, False),
_buf(0xc7, 0x55, 'hello', 0x75, 'there', 0x65, 'world', 0xc0, 0xd0, 1, 0))
def test_signedinteger(self):
self._roundtrip(-257, _buf(0x42, 0xFE, 0xFF))
self._roundtrip(-256, _buf(0x42, 0xFF, 0x00))
self._roundtrip(-255, _buf(0x42, 0xFF, 0x01))
self._roundtrip(-254, _buf(0x42, 0xFF, 0x02))
self._roundtrip(-129, _buf(0x42, 0xFF, 0x7F))
self._roundtrip(-128, _buf(0x41, 0x80))
self._roundtrip(-127, _buf(0x41, 0x81))
self._roundtrip(-4, _buf(0x41, 0xFC))
self._roundtrip(-3, _buf(0x1D))
self._roundtrip(-2, _buf(0x1E))
self._roundtrip(-1, _buf(0x1F))
self._roundtrip(0, _buf(0x10))
self._roundtrip(1, _buf(0x11))
self._roundtrip(12, _buf(0x1C))
self._roundtrip(13, _buf(0x41, 0x0D))
self._roundtrip(127, _buf(0x41, 0x7F))
self._roundtrip(128, _buf(0x42, 0x00, 0x80))
self._roundtrip(255, _buf(0x42, 0x00, 0xFF))
self._roundtrip(256, _buf(0x42, 0x01, 0x00))
self._roundtrip(32767, _buf(0x42, 0x7F, 0xFF))
self._roundtrip(32768, _buf(0x43, 0x00, 0x80, 0x00))
self._roundtrip(65535, _buf(0x43, 0x00, 0xFF, 0xFF))
self._roundtrip(65536, _buf(0x43, 0x01, 0x00, 0x00))
self._roundtrip(131072, _buf(0x43, 0x02, 0x00, 0x00))
def test_floats(self):
self._roundtrip(Float(1.0), _buf(2, 0x3f, 0x80, 0, 0))
self._roundtrip(1.0, _buf(3, 0x3f, 0xf0, 0, 0, 0, 0, 0, 0))
self._roundtrip(-1.202e300, _buf(3, 0xfe, 0x3c, 0xb7, 0xb7, 0x59, 0xbf, 0x04, 0x26))
def test_badchunks(self):
self.assertEqual(_d(_buf(0x25, 0x61, 'a', 0x35)), u'a')
self.assertEqual(_d(_buf(0x26, 0x61, 'a', 0x36)), b'a')
self.assertEqual(_d(_buf(0x27, 0x61, 'a', 0x37)), Symbol(u'a'))
for a in [0x25, 0x26, 0x27]:
for b in [0x51, 0x71]:
with self.assertRaises(DecodeError, msg='Unexpected non-binary chunk') as cm:
_d(_buf(a, b, 'a', 0x10+a))
def test_person(self):
self._roundtrip(Record((Symbol(u'titled'), Symbol(u'person'), 2, Symbol(u'thing'), 1),
[
101,
u'Blackwell',
_R(u'date', 1821, 2, 3),
u'Dr'
]),
_buf(0xB5, 0xC5, 0x76, 0x74, 0x69, 0x74, 0x6C, 0x65,
0x64, 0x76, 0x70, 0x65, 0x72, 0x73, 0x6F, 0x6E,
0x12, 0x75, 0x74, 0x68, 0x69, 0x6E, 0x67, 0x11,
0x41, 0x65, 0x59, 0x42, 0x6C, 0x61, 0x63, 0x6B,
0x77, 0x65, 0x6C, 0x6C, 0xB4, 0x74, 0x64, 0x61,
0x74, 0x65, 0x42, 0x07, 0x1D, 0x12, 0x13, 0x52,
0x44, 0x72))
def test_dict(self):
self._roundtrip({ Symbol(u'a'): 1,
u'b': True,
(1, 2, 3): b'c',
ImmutableDict({ Symbol(u'first-name'): u'Elizabeth', }):
{ Symbol(u'surname'): u'Blackwell' } },
_buf(0xE8,
0x71, "a", 0x11,
0x51, "b", 0x01,
0xC3, 0x11, 0x12, 0x13, 0x61, "c",
0xE2, 0x7A, "first-name", 0x59, "Elizabeth",
0xE2, 0x77, "surname", 0x59, "Blackwell"),
nondeterministic = True)