Update Python codec; first round of Python test updates

This commit is contained in:
Tony Garnock-Jones 2019-08-30 22:53:01 +01:00
parent 942fa30d9d
commit e35c237c34
3 changed files with 195 additions and 115 deletions

View File

@ -4,5 +4,7 @@ from .preserves import DecodeError, EncodeError, ShortPacket
from .preserves import Decoder, Encoder
from .preserves import Annotated, is_annotated, strip_annotations, annotate
from .preserves import Stream, ValueStream, SequenceStream, SetStream, DictStream
from .preserves import BinaryStream, StringStream, SymbolStream

View File

@ -20,6 +20,12 @@ class Float(object):
if other.__class__ is self.__class__:
return self.value == other.value
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash(self.value)
def __repr__(self):
return 'Float(' + repr(self.value) + ')'
@ -34,6 +40,9 @@ class Symbol(object):
def __eq__(self, other):
return isinstance(other, Symbol) and self.name == other.name
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash(self.name)
@ -54,6 +63,9 @@ class Record(object):
def __eq__(self, other):
return isinstance(other, Record) and (self.key, self.fields) == (other.key, other.fields)
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
if self.__hash is None:
self.__hash = hash((self.key, self.fields))
@ -63,15 +75,8 @@ class Record(object):
return str(self.key) + '(' + ', '.join((repr(f) for f in self.fields)) + ')'
def __preserve_on__(self, encoder):
try:
index = encoder.shortForms.index(self.key)
except ValueError:
index = None
if index is None:
encoder.header(2, 3, len(self.fields) + 1)
encoder.append(self.key)
else:
encoder.header(2, index, len(self.fields))
encoder.header(2, 0, len(self.fields) + 1)
encoder.append(self.key)
for f in self.fields:
encoder.append(f)
@ -119,6 +124,9 @@ class RecordConstructorInfo(object):
return isinstance(other, RecordConstructorInfo) and \
(self.key, self.arity) == (other.key, other.arity)
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
if self.__hash is None:
self.__hash = hash((self.key, self.arity))
@ -173,15 +181,7 @@ class DecodeError(ValueError): pass
class EncodeError(ValueError): pass
class ShortPacket(DecodeError): pass
class Codec(object):
def __init__(self):
self.shortForms = [Symbol(u'discard'), Symbol(u'capture'), Symbol(u'observe')]
def set_shortform(self, index, v):
if index >= 0 and index < 3:
self.shortForms[index] = v
else:
raise ValueError('Invalid short form index %r' % (index,))
class Codec(object): pass
class Stream(object):
def __init__(self, iterator):
@ -191,25 +191,25 @@ class Stream(object):
arg = (self.major << 2) | self.minor
encoder.leadbyte(0, 2, arg)
self._emit(encoder)
encoder.leadbyte(0, 3, arg)
encoder.leadbyte(0, 0, 4)
def _emit(self, encoder):
raise NotImplementedError('Should be implemented in subclasses')
class ValueStream(Stream):
major = 3
major = 2
def _emit(self, encoder):
for v in self._iterator:
encoder.append(v)
class SequenceStream(ValueStream):
minor = 0
class SetStream(ValueStream):
minor = 1
class DictStream(ValueStream):
class SetStream(ValueStream):
minor = 2
class DictStream(ValueStream):
minor = 3
def _emit(self, encoder):
for (k, v) in self._iterator:
encoder.append(k)
@ -230,11 +230,75 @@ class StringStream(BinaryStream):
class SymbolStream(BinaryStream):
minor = 3
inf = float('inf')
class Annotated(object):
def __init__(self, item):
self.annotations = []
self.item = item
def __preserve_on__(self, encoder):
for a in self.annotations:
encoder.header(0, 0, 5)
encoder.append(a)
encoder.append(self.item)
def strip(self, depth=inf):
return strip_annotations(self, depth)
def __eq__(self, other):
if other.__class__ is self.__class__:
return self.item == other.item
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash(self.item)
def is_annotated(v):
return isinstance(v, Annotated)
def strip_annotations(v, depth=inf):
if depth == 0: return v
if not is_annotated(v): return v
next_depth = depth - 1
def walk(v):
strip_annotations(v, next_depth)
v = v.item
if isinstance(v, Record):
return Record(walk(v.key), tuple(walk(f) for f in v.fields))
elif isinstance(v, list):
return tuple(walk(f) for f in v)
elif isinstance(v, tuple):
return tuple(walk(f) for f in v)
elif isinstance(v, set):
return frozenset(walk(f) for f in v)
elif isinstance(v, frozenset):
return frozenset(walk(f) for f in v)
elif isinstance(v, dict):
return ImmutableDict.from_kvs(walk(f) for f in dict_kvs(v))
elif is_annotated(v):
raise ValueError('Improper annotation structure')
else:
return v
def annotate(v, *anns):
if not is_annotated(v):
v = Annotated(v)
for a in anns:
v.annotations.append(a)
return v
class Decoder(Codec):
def __init__(self, packet=b''):
def __init__(self, packet=b'', placeholders={}, include_annotations=False):
super(Decoder, self).__init__()
self.packet = packet
self.index = 0
self.placeholders = placeholders
self.include_annotations = include_annotations
def extend(self, data):
self.packet = self.packet[self.index:] + data
@ -279,15 +343,15 @@ class Decoder(Codec):
arg = b & 15
return (major, minor, arg)
def peekend(self, arg):
matched = (self.nextop() == (0, 3, arg))
def peekend(self):
matched = (self.nextbyte() == 4)
if not matched:
self.index = self.index - 1
return matched
def binarystream(self, arg, minor):
def binarystream(self, minor):
result = []
while not self.peekend(arg):
while not self.peekend():
chunk = self.next()
if isinstance(chunk, bytes):
result.append(chunk)
@ -295,11 +359,11 @@ class Decoder(Codec):
raise DecodeError('Unexpected non-binary chunk')
return self.decodebinary(minor, b''.join(result))
def valuestream(self, arg, minor, decoder):
def valuestream(self, minor):
result = []
while not self.peekend(arg):
while not self.peekend():
result.append(self.next())
return decoder(minor, result)
return self.decodecompound(minor, result)
def decodeint(self, bs):
if len(bs) == 0: return 0
@ -315,45 +379,56 @@ class Decoder(Codec):
if minor == 2: return bs
if minor == 3: return Symbol(bs.decode('utf-8'))
def decoderecord(self, minor, vs):
if minor == 3:
def decodecompound(self, minor, vs):
if minor == 0:
if not vs: raise DecodeError('Too few elements in encoded record')
return Record(vs[0], vs[1:])
else:
return Record(self.shortForms[minor], vs)
if minor == 1: return tuple(vs)
if minor == 2: return frozenset(vs)
if minor == 3: return ImmutableDict.from_kvs(vs)
def decodecollection(self, minor, vs):
if minor == 0: return tuple(vs)
if minor == 1: return frozenset(vs)
if minor == 2: return ImmutableDict.from_kvs(vs)
if minor == 3: raise DecodeError('Invalid collection type')
def wrap(self, v):
return Annotated(v) if self.include_annotations else v
def unshift_annotation(self, a, v):
if this.include_annotations:
v.annotations.insert(0, a)
return v
def next(self):
(major, minor, arg) = self.nextop()
if major == 0:
if minor == 0:
if arg == 0: return False
if arg == 1: return True
if arg == 2: return Float(struct.unpack('>f', self.nextbytes(4))[0])
if arg == 3: return struct.unpack('>d', self.nextbytes(8))[0]
if arg == 0: return self.wrap(False)
if arg == 1: return self.wrap(True)
if arg == 2: return self.wrap(Float(struct.unpack('>f', self.nextbytes(4))[0]))
if arg == 3: return self.wrap(struct.unpack('>d', self.nextbytes(8))[0])
if arg == 4: raise DecodeError('Unexpected end-of-stream marker')
if arg == 5:
a = self.next()
v = self.next()
return self.unshift_annotation(a, v)
raise DecodeError('Invalid format A encoding')
elif minor == 1:
return arg - 16 if arg > 12 else arg
n = self.wirelength(arg)
v = self.placeholders.get(n, None)
if v is None:
raise DecodeError('Invalid Preserves placeholder')
return self.wrap(v)
elif minor == 2:
t = arg >> 2
n = arg & 3
if t == 0: raise DecodeError('Invalid format C start byte')
if t == 1: return self.binarystream(arg, n)
if t == 2: return self.valuestream(arg, n, self.decoderecord)
if t == 3: return self.valuestream(arg, n, self.decodecollection)
if t == 1: return self.wrap(self.binarystream(n))
if t == 2: return self.wrap(self.valuestream(n))
raise DecodeError('Invalid format C start byte')
else: # minor == 3
raise DecodeError('Unexpected format C end byte')
return self.wrap(arg - 16 if arg > 12 else arg)
elif major == 1:
return self.decodebinary(minor, self.nextbytes(self.wirelength(arg)))
return self.wrap(self.decodebinary(minor, self.nextbytes(self.wirelength(arg))))
elif major == 2:
return self.decoderecord(minor, self.nextvalues(self.wirelength(arg)))
return self.wrap(self.decodecompound(minor, self.nextvalues(self.wirelength(arg))))
else: # major == 3
return self.decodecollection(minor, self.nextvalues(self.wirelength(arg)))
raise DecodeError('Invalid lead byte (major 3)')
def try_next(self):
start = self.index
@ -364,9 +439,10 @@ class Decoder(Codec):
return None
class Encoder(Codec):
def __init__(self):
def __init__(self, placeholders={}):
super(Encoder, self).__init__()
self.buffer = bytearray()
self.placeholders = placeholders
def contents(self):
return bytes(self.buffer)
@ -399,17 +475,23 @@ class Encoder(Codec):
enc(bytecount, v)
def encodecollection(self, minor, items):
self.header(3, minor, len(items))
self.header(2, minor, len(items))
for i in items: self.append(i)
def encodestream(self, t, n, items):
tn = ((t & 3) << 2) | (n & 3)
self.header(0, 2, tn)
self.leadbyte(0, 2, tn)
for i in items: self.append(i)
self.header(0, 3, tn)
self.leadbyte(0, 0, 4)
def append(self, v):
if hasattr(v, '__preserve_on__'):
try:
placeholder = self.placeholders.get(v, None)
except TypeError: ## some types (e.g. list) yield 'unhashable type'
placeholder = None
if placeholder is not None:
self.header(0, 1, placeholder)
elif hasattr(v, '__preserve_on__'):
v.__preserve_on__(self)
elif v is False:
self.leadbyte(0, 0, 0)
@ -420,7 +502,7 @@ class Encoder(Codec):
self.buffer.extend(struct.pack('>d', v))
elif isinstance(v, numbers.Number):
if v >= -3 and v <= 12:
self.leadbyte(0, 1, v if v >= 0 else v + 16)
self.leadbyte(0, 3, v if v >= 0 else v + 16)
else:
self.encodeint(v)
elif isinstance(v, bytes):
@ -431,18 +513,18 @@ class Encoder(Codec):
self.header(1, 1, len(bs))
self.buffer.extend(bs)
elif isinstance(v, list):
self.encodecollection(0, v)
self.encodecollection(1, v)
elif isinstance(v, tuple):
self.encodecollection(0, v)
self.encodecollection(1, v)
elif isinstance(v, set):
self.encodecollection(1, v)
self.encodecollection(2, v)
elif isinstance(v, frozenset):
self.encodecollection(1, v)
self.encodecollection(2, v)
elif isinstance(v, dict):
self.encodecollection(2, list(dict_kvs(v)))
self.encodecollection(3, list(dict_kvs(v)))
else:
try:
i = iter(v)
except TypeError:
raise EncodeError('Cannot encode %r' % (v,))
self.encodestream(3, 0, i)
self.encodestream(2, 1, i)

View File

@ -32,7 +32,11 @@ def _varint(v):
return e.contents()
def _d(bs):
d = Decoder(bs)
d = Decoder(bs, placeholders={
0: Symbol('discard'),
1: Symbol('capture'),
2: Symbol('observe'),
})
return d.next()
_all_encoded = set()
@ -43,7 +47,11 @@ def tearDownModule():
print(_hex(bs))
def _e(v):
e = Encoder()
e = Encoder(placeholders={
Symbol('discard'): 0,
Symbol('capture'): 1,
Symbol('observe'): 2,
})
e.append(v)
bs = e.contents()
_all_encoded.add(bs)
@ -81,33 +89,34 @@ class CodecTests(unittest.TestCase):
self.assertEqual(_varint(1000000000), _buf(128, 148, 235, 220, 3))
def test_shorts(self):
self._roundtrip(_R('capture', _R('discard')), _buf(0x91, 0x80))
self._roundtrip(_R('capture', _R('discard')), _buf(0x82, 0x11, 0x81, 0x10))
self._roundtrip(_R('observe', _R('speak', _R('discard'), _R('capture', _R('discard')))),
_buf(0xA1, 0xB3, 0x75, "speak", 0x80, 0x91, 0x80))
_buf(0x82, 0x12, 0x83, 0x75, "speak", 0x81, 0x10, 0x82, 0x11, 0x81, 0x10))
def test_simple_seq(self):
self._roundtrip([1,2,3,4], _buf(0xC4, 0x11, 0x12, 0x13, 0x14), back=(1,2,3,4))
self._roundtrip(SequenceStream([1,2,3,4]), _buf(0x2C, 0x11, 0x12, 0x13, 0x14, 0x3C),
self._roundtrip([1,2,3,4], _buf(0x94, 0x31, 0x32, 0x33, 0x34), back=(1,2,3,4))
self._roundtrip(SequenceStream([1,2,3,4]), _buf(0x29, 0x31, 0x32, 0x33, 0x34, 0x04),
back=(1,2,3,4))
self._roundtrip((-2,-1,0,1), _buf(0xC4, 0x1E, 0x1F, 0x10, 0x11))
self._roundtrip((-2,-1,0,1), _buf(0x94, 0x3E, 0x3F, 0x30, 0x31))
def test_str(self):
self._roundtrip(u'hello', _buf(0x55, 'hello'))
self._roundtrip(StringStream([b'he', b'llo']), _buf(0x25, 0x62, 'he', 0x63, 'llo', 0x35),
self._roundtrip(StringStream([b'he', b'llo']), _buf(0x25, 0x62, 'he', 0x63, 'llo', 0x04),
back=u'hello')
## TODO: error with zero-size chunks
self._roundtrip(StringStream([b'he', b'll', b'', b'', b'o']),
_buf(0x25, 0x62, 'he', 0x62, 'll', 0x60, 0x60, 0x61, 'o', 0x35),
_buf(0x25, 0x62, 'he', 0x62, 'll', 0x60, 0x60, 0x61, 'o', 0x04),
back=u'hello')
self._roundtrip(BinaryStream([b'he', b'll', b'', b'', b'o']),
_buf(0x26, 0x62, 'he', 0x62, 'll', 0x60, 0x60, 0x61, 'o', 0x36),
_buf(0x26, 0x62, 'he', 0x62, 'll', 0x60, 0x60, 0x61, 'o', 0x04),
back=b'hello')
self._roundtrip(SymbolStream([b'he', b'll', b'', b'', b'o']),
_buf(0x27, 0x62, 'he', 0x62, 'll', 0x60, 0x60, 0x61, 'o', 0x37),
_buf(0x27, 0x62, 'he', 0x62, 'll', 0x60, 0x60, 0x61, 'o', 0x04),
back=Symbol(u'hello'))
def test_mixed1(self):
self._roundtrip((u'hello', Symbol(u'there'), b'world', (), set(), True, False),
_buf(0xc7, 0x55, 'hello', 0x75, 'there', 0x65, 'world', 0xc0, 0xd0, 1, 0))
_buf(0x97, 0x55, 'hello', 0x75, 'there', 0x65, 'world', 0x90, 0xa0, 1, 0))
def test_signedinteger(self):
self._roundtrip(-257, _buf(0x42, 0xFE, 0xFF))
@ -118,12 +127,12 @@ class CodecTests(unittest.TestCase):
self._roundtrip(-128, _buf(0x41, 0x80))
self._roundtrip(-127, _buf(0x41, 0x81))
self._roundtrip(-4, _buf(0x41, 0xFC))
self._roundtrip(-3, _buf(0x1D))
self._roundtrip(-2, _buf(0x1E))
self._roundtrip(-1, _buf(0x1F))
self._roundtrip(0, _buf(0x10))
self._roundtrip(1, _buf(0x11))
self._roundtrip(12, _buf(0x1C))
self._roundtrip(-3, _buf(0x3D))
self._roundtrip(-2, _buf(0x3E))
self._roundtrip(-1, _buf(0x3F))
self._roundtrip(0, _buf(0x30))
self._roundtrip(1, _buf(0x31))
self._roundtrip(12, _buf(0x3C))
self._roundtrip(13, _buf(0x41, 0x0D))
self._roundtrip(127, _buf(0x41, 0x7F))
self._roundtrip(128, _buf(0x42, 0x00, 0x80))
@ -141,29 +150,13 @@ class CodecTests(unittest.TestCase):
self._roundtrip(-1.202e300, _buf(3, 0xfe, 0x3c, 0xb7, 0xb7, 0x59, 0xbf, 0x04, 0x26))
def test_badchunks(self):
self.assertEqual(_d(_buf(0x25, 0x61, 'a', 0x35)), u'a')
self.assertEqual(_d(_buf(0x26, 0x61, 'a', 0x36)), b'a')
self.assertEqual(_d(_buf(0x27, 0x61, 'a', 0x37)), Symbol(u'a'))
self.assertEqual(_d(_buf(0x25, 0x61, 'a', 0x04)), u'a')
self.assertEqual(_d(_buf(0x26, 0x61, 'a', 0x04)), b'a')
self.assertEqual(_d(_buf(0x27, 0x61, 'a', 0x04)), Symbol(u'a'))
for a in [0x25, 0x26, 0x27]:
for b in [0x51, 0x71]:
with self.assertRaises(DecodeError, msg='Unexpected non-binary chunk') as cm:
_d(_buf(a, b, 'a', 0x10+a))
def test_person(self):
self._roundtrip(Record((Symbol(u'titled'), Symbol(u'person'), 2, Symbol(u'thing'), 1),
[
101,
u'Blackwell',
_R(u'date', 1821, 2, 3),
u'Dr'
]),
_buf(0xB5, 0xC5, 0x76, 0x74, 0x69, 0x74, 0x6C, 0x65,
0x64, 0x76, 0x70, 0x65, 0x72, 0x73, 0x6F, 0x6E,
0x12, 0x75, 0x74, 0x68, 0x69, 0x6E, 0x67, 0x11,
0x41, 0x65, 0x59, 0x42, 0x6C, 0x61, 0x63, 0x6B,
0x77, 0x65, 0x6C, 0x6C, 0xB4, 0x74, 0x64, 0x61,
0x74, 0x65, 0x42, 0x07, 0x1D, 0x12, 0x13, 0x52,
0x44, 0x72))
_d(_buf(a, b, 'a', 0x04))
def test_dict(self):
self._roundtrip({ Symbol(u'a'): 1,
@ -171,17 +164,17 @@ class CodecTests(unittest.TestCase):
(1, 2, 3): b'c',
ImmutableDict({ Symbol(u'first-name'): u'Elizabeth', }):
{ Symbol(u'surname'): u'Blackwell' } },
_buf(0xE8,
0x71, "a", 0x11,
_buf(0xB8,
0x71, "a", 0x31,
0x51, "b", 0x01,
0xC3, 0x11, 0x12, 0x13, 0x61, "c",
0xE2, 0x7A, "first-name", 0x59, "Elizabeth",
0xE2, 0x77, "surname", 0x59, "Blackwell"),
0x93, 0x31, 0x32, 0x33, 0x61, "c",
0xB2, 0x7A, "first-name", 0x59, "Elizabeth",
0xB2, 0x77, "surname", 0x59, "Blackwell"),
nondeterministic = True)
def test_iterator_stream(self):
d = {u'a': 1, u'b': 2, u'c': 3}
r = r'2c(c2516.1.){3}3c'
r = r'29(92516.3.){3}04'
if hasattr(d, 'iteritems'):
# python 2
bs = _e(d.iteritems())
@ -194,17 +187,20 @@ class CodecTests(unittest.TestCase):
def test_long_sequence(self):
# Short enough to not need a varint:
self._roundtrip((False,) * 14, _buf(0xCE, b'\x00' * 14))
self._roundtrip((False,) * 14, _buf(0x9E, b'\x00' * 14))
# Varint-needing:
self._roundtrip((False,) * 15, _buf(0xCF, 0x0F, b'\x00' * 15))
self._roundtrip((False,) * 100, _buf(0xCF, 0x64, b'\x00' * 100))
self._roundtrip((False,) * 200, _buf(0xCF, 0xC8, 0x01, b'\x00' * 200))
self._roundtrip((False,) * 15, _buf(0x9F, 0x0F, b'\x00' * 15))
self._roundtrip((False,) * 100, _buf(0x9F, 0x64, b'\x00' * 100))
self._roundtrip((False,) * 200, _buf(0x9F, 0xC8, 0x01, b'\x00' * 200))
def test_format_c_twice(self):
self._roundtrip(SequenceStream([StringStream([b'abc']), StringStream([b'def'])]),
_buf(0x2C, 0x25, 0x63, 'abc', 0x35, 0x25, 0x63, 'def', 0x35, 0x3C),
_buf(0x29, 0x25, 0x63, 'abc', 0x04, 0x25, 0x63, 'def', 0x04, 0x04),
back=(u'abc', u'def'))
def test_common_test_suite(self):
self.fail('Common test suite needs to be implemented')
class RecordTests(unittest.TestCase):
def test_getters(self):
T = Record.makeConstructor('t', 'x y z')