Update python implementation

This commit is contained in:
Tony Garnock-Jones 2020-12-30 19:24:37 +01:00
parent 8459521db5
commit ca2276d268
4 changed files with 118 additions and 262 deletions

View File

@ -5,6 +5,3 @@ from .preserves import DecodeError, EncodeError, ShortPacket
from .preserves import Decoder, Encoder, decode, decode_with_annotations, encode
from .preserves import Annotated, is_annotated, strip_annotations, annotate
from .preserves import Stream, ValueStream, SequenceStream, SetStream, DictStream
from .preserves import BinaryStream, StringStream, SymbolStream

View File

@ -30,7 +30,7 @@ class Float(object):
return 'Float(' + repr(self.value) + ')'
def __preserve_on__(self, encoder):
encoder.leadbyte(0, 0, 2)
encoder.buffer.append(0x82)
encoder.buffer.extend(struct.pack('>f', self.value))
class Symbol(object):
@ -51,7 +51,8 @@ class Symbol(object):
def __preserve_on__(self, encoder):
bs = self.name.encode('utf-8')
encoder.header(1, 3, len(bs))
encoder.buffer.append(0xb3)
encoder.varint(len(bs))
encoder.buffer.extend(bs)
class Record(object):
@ -75,10 +76,11 @@ class Record(object):
return str(self.key) + '(' + ', '.join((repr(f) for f in self.fields)) + ')'
def __preserve_on__(self, encoder):
encoder.header(2, 0, len(self.fields) + 1)
encoder.buffer.append(0xb4)
encoder.append(self.key)
for f in self.fields:
encoder.append(f)
encoder.buffer.append(0x84)
def __getitem__(self, index):
return self.fields[index]
@ -186,53 +188,6 @@ class ShortPacket(DecodeError): pass
class Codec(object): pass
class Stream(object):
def __init__(self, iterator):
self._iterator = iterator
def __preserve_on__(self, encoder):
arg = (self.major << 2) | self.minor
encoder.leadbyte(0, 2, arg)
self._emit(encoder)
encoder.leadbyte(0, 0, 4)
def _emit(self, encoder):
raise NotImplementedError('Should be implemented in subclasses')
class ValueStream(Stream):
major = 2
def _emit(self, encoder):
for v in self._iterator:
encoder.append(v)
class SequenceStream(ValueStream):
minor = 1
class SetStream(ValueStream):
minor = 2
class DictStream(ValueStream):
minor = 3
def _emit(self, encoder):
for (k, v) in self._iterator:
encoder.append(k)
encoder.append(v)
class BinaryStream(Stream):
major = 1
minor = 2
def _emit(self, encoder):
for chunk in self._iterator:
if not isinstance(chunk, bytes) or len(chunk) == 0:
raise EncodeError('Illegal chunk in BinaryStream %r' % (chunk,))
encoder.append(chunk)
class StringStream(BinaryStream):
minor = 1
class SymbolStream(BinaryStream):
minor = 3
inf = float('inf')
class Annotated(object):
@ -242,7 +197,7 @@ class Annotated(object):
def __preserve_on__(self, encoder):
for a in self.annotations:
encoder.header(0, 0, 5)
encoder.buffer.append(0x85)
encoder.append(a)
encoder.append(self.item)
@ -326,11 +281,6 @@ class Decoder(Codec):
self.index = end
return self.packet[start : end]
def wirelength(self, arg):
if arg < 15:
return arg
return self.varint()
def varint(self):
v = self.nextbyte()
if v < 128:
@ -338,66 +288,26 @@ class Decoder(Codec):
else:
return self.varint() * 128 + (v - 128)
def nextvalues(self, n):
result = []
for i in range(n):
result.append(self.next())
return result
def nextop(self):
b = self.nextbyte()
major = b >> 6
minor = (b >> 4) & 3
arg = b & 15
return (major, minor, arg)
def peekend(self):
matched = (self.nextbyte() == 4)
matched = (self.nextbyte() == 0x84)
if not matched:
self.index = self.index - 1
return matched
def binarystream(self, minor):
result = []
while not self.peekend():
chunk = strip_annotations(self.next())
if isinstance(chunk, bytes):
if len(chunk) > 0:
result.append(chunk)
else:
raise DecodeError('Empty binary chunks are forbidden')
else:
raise DecodeError('Unexpected non-binary chunk', chunk, isinstance(chunk, bytes), type(chunk))
return self.decodebinary(minor, b''.join(result))
def valuestream(self, minor):
def nextvalues(self):
result = []
while not self.peekend():
result.append(self.next())
return self.decodecompound(minor, result)
return result
def decodeint(self, bs):
if len(bs) == 0: return 0
acc = _ord(bs[0])
def nextint(self, n):
if n == 0: return 0
acc = self.nextbyte()
if acc & 0x80: acc = acc - 256
for b in bs[1:]:
acc = (acc << 8) | _ord(b)
for _i in range(n - 1):
acc = (acc << 8) | self.nextbyte()
return acc
def decodebinary(self, minor, bs):
if minor == 0: return self.decodeint(bs)
if minor == 1: return bs.decode('utf-8')
if minor == 2: return bs
if minor == 3: return Symbol(bs.decode('utf-8'))
def decodecompound(self, minor, vs):
if minor == 0:
if not vs: raise DecodeError('Too few elements in encoded record')
return Record(vs[0], vs[1:])
if minor == 1: return tuple(vs)
if minor == 2: return frozenset(vs)
if minor == 3: return ImmutableDict.from_kvs(vs)
def wrap(self, v):
return Annotated(v) if self.include_annotations else v
@ -407,40 +317,30 @@ class Decoder(Codec):
return v
def next(self):
while True: # we loop because we may need to consume an arbitrary number of no-ops
(major, minor, arg) = self.nextop()
if major == 0:
if minor == 0:
if arg == 0: return self.wrap(False)
if arg == 1: return self.wrap(True)
if arg == 2: return self.wrap(Float(struct.unpack('>f', self.nextbytes(4))[0]))
if arg == 3: return self.wrap(struct.unpack('>d', self.nextbytes(8))[0])
if arg == 4: raise DecodeError('Unexpected end-of-stream marker')
if arg == 5:
a = self.next()
v = self.next()
return self.unshift_annotation(a, v)
raise DecodeError('Invalid format A encoding')
elif minor == 1:
raise DecodeError('Invalid format A encoding')
elif minor == 2:
t = arg >> 2
n = arg & 3
if t == 1: return self.wrap(self.binarystream(n))
if t == 2: return self.wrap(self.valuestream(n))
raise DecodeError('Invalid format C start byte')
else: # minor == 3
return self.wrap(arg - 16 if arg > 12 else arg)
elif major == 1:
return self.wrap(self.decodebinary(minor, self.nextbytes(self.wirelength(arg))))
elif major == 2:
return self.wrap(self.decodecompound(minor, self.nextvalues(self.wirelength(arg))))
else: # major == 3
if minor == 3 and arg == 15:
# no-op.
continue
else:
raise DecodeError('Invalid lead byte (major 3)')
tag = self.nextbyte()
if tag == 0x80: return self.wrap(False)
if tag == 0x81: return self.wrap(True)
if tag == 0x82: return self.wrap(Float(struct.unpack('>f', self.nextbytes(4))[0]))
if tag == 0x83: return self.wrap(struct.unpack('>d', self.nextbytes(8))[0])
if tag == 0x84: raise DecodeError('Unexpected end-of-stream marker')
if tag == 0x85:
a = self.next()
v = self.next()
return self.unshift_annotation(a, v)
if tag >= 0x90 and tag <= 0x9f: return self.wrap(tag - (0xa0 if tag > 0x9c else 0x90))
if tag >= 0xa0 and tag <= 0xaf: return self.wrap(self.nextint(tag - 0xa0 + 1))
if tag == 0xb0: return self.wrap(self.nextint(self.varint()))
if tag == 0xb1: return self.wrap(self.nextbytes(self.varint()).decode('utf-8'))
if tag == 0xb2: return self.wrap(self.nextbytes(self.varint()))
if tag == 0xb3: return self.wrap(Symbol(self.nextbytes(self.varint()).decode('utf-8')))
if tag == 0xb4:
vs = self.nextvalues()
if not vs: raise DecodeError('Too few elements in encoded record')
return self.wrap(Record(vs[0], vs[1:]))
if tag == 0xb5: return self.wrap(tuple(self.nextvalues()))
if tag == 0xb6: return self.wrap(frozenset(self.nextvalues()))
if tag == 0xb7: return self.wrap(ImmutableDict.from_kvs(self.nextvalues()))
raise DecodeError('Invalid tag: ' + hex(tag))
def try_next(self):
start = self.index
@ -471,77 +371,65 @@ class Encoder(Codec):
self.buffer.append((v % 128) + 128)
self.varint(v // 128)
def leadbyte(self, major, minor, arg):
self.buffer.append(((major & 3) << 6) | ((minor & 3) << 4) | (arg & 15))
def header(self, major, minor, wirelength):
if wirelength < 15:
self.leadbyte(major, minor, wirelength)
else:
self.leadbyte(major, minor, 15)
self.varint(wirelength)
def encodeint(self, v):
bitcount = (~v if v < 0 else v).bit_length() + 1
bytecount = (bitcount + 7) // 8
self.header(1, 0, bytecount)
if bytecount <= 16:
self.buffer.append(0xa0 + bytecount - 1)
else:
self.buffer.append(0xb0)
self.varint(bytecount)
def enc(n,x):
if n > 0:
enc(n-1, x >> 8)
self.buffer.append(x & 255)
enc(bytecount, v)
def encodecollection(self, minor, items):
self.header(2, minor, len(items))
def encodevalues(self, tag, items):
self.buffer.append(0xb0 + tag)
for i in items: self.append(i)
self.buffer.append(0x84)
def encodestream(self, t, n, items):
tn = ((t & 3) << 2) | (n & 3)
self.leadbyte(0, 2, tn)
for i in items: self.append(i)
self.leadbyte(0, 0, 4)
def encodenoop(self):
self.leadbyte(3, 3, 15)
def encodebytes(self, tag, bs):
self.buffer.append(0xb0 + tag)
self.varint(len(bs))
self.buffer.extend(bs)
def append(self, v):
if hasattr(v, '__preserve_on__'):
v.__preserve_on__(self)
elif v is False:
self.leadbyte(0, 0, 0)
self.buffer.append(0x80)
elif v is True:
self.leadbyte(0, 0, 1)
self.buffer.append(0x81)
elif isinstance(v, float):
self.leadbyte(0, 0, 3)
self.buffer.append(0x83)
self.buffer.extend(struct.pack('>d', v))
elif isinstance(v, numbers.Number):
if v >= -3 and v <= 12:
self.leadbyte(0, 3, v if v >= 0 else v + 16)
self.buffer.append(0x90 + (v if v >= 0 else v + 16))
else:
self.encodeint(v)
elif isinstance(v, bytes):
self.header(1, 2, len(v))
self.buffer.extend(v)
self.encodebytes(2, v)
elif isinstance(v, basestring):
bs = v.encode('utf-8')
self.header(1, 1, len(bs))
self.buffer.extend(bs)
self.encodebytes(1, v.encode('utf-8'))
elif isinstance(v, list):
self.encodecollection(1, v)
self.encodevalues(5, v)
elif isinstance(v, tuple):
self.encodecollection(1, v)
self.encodevalues(5, v)
elif isinstance(v, set):
self.encodecollection(2, v)
self.encodevalues(6, v)
elif isinstance(v, frozenset):
self.encodecollection(2, v)
self.encodevalues(6, v)
elif isinstance(v, dict):
self.encodecollection(3, list(dict_kvs(v)))
self.encodevalues(7, list(dict_kvs(v)))
else:
try:
i = iter(v)
except TypeError:
raise EncodeError('Cannot encode %r' % (v,))
self.encodestream(2, 1, i)
self.encodevalues(5, i)
def encode(v):
e = Encoder()

View File

@ -70,68 +70,58 @@ class CodecTests(unittest.TestCase):
self.assertEqual(_varint(1000000000), _buf(128, 148, 235, 220, 3))
def test_simple_seq(self):
self._roundtrip([1,2,3,4], _buf(0x94, 0x31, 0x32, 0x33, 0x34), back=(1,2,3,4))
self._roundtrip(SequenceStream([1,2,3,4]), _buf(0x29, 0x31, 0x32, 0x33, 0x34, 0x04),
back=(1,2,3,4))
self._roundtrip((-2,-1,0,1), _buf(0x94, 0x3E, 0x3F, 0x30, 0x31))
self._roundtrip([1,2,3,4], _buf(0xb5, 0x91, 0x92, 0x93, 0x94, 0x84), back=(1,2,3,4))
self._roundtrip(iter([1,2,3,4]),
_buf(0xb5, 0x91, 0x92, 0x93, 0x94, 0x84),
back=(1,2,3,4),
nondeterministic=True)
self._roundtrip((-2,-1,0,1), _buf(0xb5, 0x9E, 0x9F, 0x90, 0x91, 0x84))
def test_str(self):
self._roundtrip(u'hello', _buf(0x55, 'hello'))
self._roundtrip(StringStream([b'he', b'llo']), _buf(0x25, 0x62, 'he', 0x63, 'llo', 0x04),
back=u'hello')
self._roundtrip(StringStream([b'he', b'll', b'o']),
_buf(0x25, 0x62, 'he', 0x62, 'll', 0x61, 'o', 0x04),
back=u'hello')
self._roundtrip(BinaryStream([b'he', b'll', b'o']),
_buf(0x26, 0x62, 'he', 0x62, 'll', 0x61, 'o', 0x04),
back=b'hello')
self._roundtrip(SymbolStream([b'he', b'll', b'o']),
_buf(0x27, 0x62, 'he', 0x62, 'll', 0x61, 'o', 0x04),
back=Symbol(u'hello'))
self._roundtrip(u'hello', _buf(0xb1, 0x05, 'hello'))
def test_mixed1(self):
self._roundtrip((u'hello', Symbol(u'there'), b'world', (), set(), True, False),
_buf(0x97, 0x55, 'hello', 0x75, 'there', 0x65, 'world', 0x90, 0xa0, 1, 0))
_buf(0xb5,
0xb1, 0x05, 'hello',
0xb3, 0x05, 'there',
0xb2, 0x05, 'world',
0xb5, 0x84,
0xb6, 0x84,
0x81,
0x80,
0x84))
def test_signedinteger(self):
self._roundtrip(-257, _buf(0x42, 0xFE, 0xFF))
self._roundtrip(-256, _buf(0x42, 0xFF, 0x00))
self._roundtrip(-255, _buf(0x42, 0xFF, 0x01))
self._roundtrip(-254, _buf(0x42, 0xFF, 0x02))
self._roundtrip(-129, _buf(0x42, 0xFF, 0x7F))
self._roundtrip(-128, _buf(0x41, 0x80))
self._roundtrip(-127, _buf(0x41, 0x81))
self._roundtrip(-4, _buf(0x41, 0xFC))
self._roundtrip(-3, _buf(0x3D))
self._roundtrip(-2, _buf(0x3E))
self._roundtrip(-1, _buf(0x3F))
self._roundtrip(0, _buf(0x30))
self._roundtrip(1, _buf(0x31))
self._roundtrip(12, _buf(0x3C))
self._roundtrip(13, _buf(0x41, 0x0D))
self._roundtrip(127, _buf(0x41, 0x7F))
self._roundtrip(128, _buf(0x42, 0x00, 0x80))
self._roundtrip(255, _buf(0x42, 0x00, 0xFF))
self._roundtrip(256, _buf(0x42, 0x01, 0x00))
self._roundtrip(32767, _buf(0x42, 0x7F, 0xFF))
self._roundtrip(32768, _buf(0x43, 0x00, 0x80, 0x00))
self._roundtrip(65535, _buf(0x43, 0x00, 0xFF, 0xFF))
self._roundtrip(65536, _buf(0x43, 0x01, 0x00, 0x00))
self._roundtrip(131072, _buf(0x43, 0x02, 0x00, 0x00))
self._roundtrip(-257, _buf(0xa1, 0xFE, 0xFF))
self._roundtrip(-256, _buf(0xa1, 0xFF, 0x00))
self._roundtrip(-255, _buf(0xa1, 0xFF, 0x01))
self._roundtrip(-254, _buf(0xa1, 0xFF, 0x02))
self._roundtrip(-129, _buf(0xa1, 0xFF, 0x7F))
self._roundtrip(-128, _buf(0xa0, 0x80))
self._roundtrip(-127, _buf(0xa0, 0x81))
self._roundtrip(-4, _buf(0xa0, 0xFC))
self._roundtrip(-3, _buf(0x9D))
self._roundtrip(-2, _buf(0x9E))
self._roundtrip(-1, _buf(0x9F))
self._roundtrip(0, _buf(0x90))
self._roundtrip(1, _buf(0x91))
self._roundtrip(12, _buf(0x9C))
self._roundtrip(13, _buf(0xa0, 0x0D))
self._roundtrip(127, _buf(0xa0, 0x7F))
self._roundtrip(128, _buf(0xa1, 0x00, 0x80))
self._roundtrip(255, _buf(0xa1, 0x00, 0xFF))
self._roundtrip(256, _buf(0xa1, 0x01, 0x00))
self._roundtrip(32767, _buf(0xa1, 0x7F, 0xFF))
self._roundtrip(32768, _buf(0xa2, 0x00, 0x80, 0x00))
self._roundtrip(65535, _buf(0xa2, 0x00, 0xFF, 0xFF))
self._roundtrip(65536, _buf(0xa2, 0x01, 0x00, 0x00))
self._roundtrip(131072, _buf(0xa2, 0x02, 0x00, 0x00))
def test_floats(self):
self._roundtrip(Float(1.0), _buf(2, 0x3f, 0x80, 0, 0))
self._roundtrip(1.0, _buf(3, 0x3f, 0xf0, 0, 0, 0, 0, 0, 0))
self._roundtrip(-1.202e300, _buf(3, 0xfe, 0x3c, 0xb7, 0xb7, 0x59, 0xbf, 0x04, 0x26))
def test_badchunks(self):
self.assertEqual(_d(_buf(0x25, 0x61, 'a', 0x04)), u'a')
self.assertEqual(_d(_buf(0x26, 0x61, 'a', 0x04)), b'a')
self.assertEqual(_d(_buf(0x27, 0x61, 'a', 0x04)), Symbol(u'a'))
for a in [0x25, 0x26, 0x27]:
for b in [0x51, 0x71]:
with self.assertRaises(DecodeError, msg='Unexpected non-binary chunk') as cm:
_d(_buf(a, b, 'a', 0x04))
self._roundtrip(Float(1.0), _buf(0x82, 0x3f, 0x80, 0, 0))
self._roundtrip(1.0, _buf(0x83, 0x3f, 0xf0, 0, 0, 0, 0, 0, 0))
self._roundtrip(-1.202e300, _buf(0x83, 0xfe, 0x3c, 0xb7, 0xb7, 0x59, 0xbf, 0x04, 0x26))
def test_dict(self):
self._roundtrip({ Symbol(u'a'): 1,
@ -139,17 +129,18 @@ class CodecTests(unittest.TestCase):
(1, 2, 3): b'c',
ImmutableDict({ Symbol(u'first-name'): u'Elizabeth', }):
{ Symbol(u'surname'): u'Blackwell' } },
_buf(0xB8,
0x71, "a", 0x31,
0x51, "b", 0x01,
0x93, 0x31, 0x32, 0x33, 0x61, "c",
0xB2, 0x7A, "first-name", 0x59, "Elizabeth",
0xB2, 0x77, "surname", 0x59, "Blackwell"),
_buf(0xB7,
0xb3, 0x01, "a", 0x91,
0xb1, 0x01, "b", 0x81,
0xb5, 0x91, 0x92, 0x93, 0x84, 0xb2, 0x01, "c",
0xB7, 0xb3, 0x0A, "first-name", 0xb1, 0x09, "Elizabeth", 0x84,
0xB7, 0xb3, 0x07, "surname", 0xb1, 0x09, "Blackwell", 0x84,
0x84),
nondeterministic = True)
def test_iterator_stream(self):
d = {u'a': 1, u'b': 2, u'c': 3}
r = r'29(92516.3.){3}04'
r = r'b5(b5b1016.9.84){3}84'
if hasattr(d, 'iteritems'):
# python 2
bs = _e(d.iteritems())
@ -161,17 +152,10 @@ class CodecTests(unittest.TestCase):
self.assertEqual(sorted(_d(bs)), [(u'a', 1), (u'b', 2), (u'c', 3)])
def test_long_sequence(self):
# Short enough to not need a varint:
self._roundtrip((False,) * 14, _buf(0x9E, b'\x00' * 14))
# Varint-needing:
self._roundtrip((False,) * 15, _buf(0x9F, 0x0F, b'\x00' * 15))
self._roundtrip((False,) * 100, _buf(0x9F, 0x64, b'\x00' * 100))
self._roundtrip((False,) * 200, _buf(0x9F, 0xC8, 0x01, b'\x00' * 200))
def test_format_c_twice(self):
self._roundtrip(SequenceStream([StringStream([b'abc']), StringStream([b'def'])]),
_buf(0x29, 0x25, 0x63, 'abc', 0x04, 0x25, 0x63, 'def', 0x04, 0x04),
back=(u'abc', u'def'))
self._roundtrip((False,) * 14, _buf(0xb5, b'\x80' * 14, 0x84))
self._roundtrip((False,) * 15, _buf(0xb5, b'\x80' * 15, 0x84))
self._roundtrip((False,) * 100, _buf(0xb5, b'\x80' * 100, 0x84))
self._roundtrip((False,) * 200, _buf(0xb5, b'\x80' * 200, 0x84))
def add_method(d, tName, fn):
if hasattr(fn, 'func_name'):
@ -196,17 +180,7 @@ expected_values = {
"back": _R('R', Symbol('f')) },
"annotation7": { "forward": annotate([], Symbol('a'), Symbol('b'), Symbol('c')),
"back": () },
"bytes1": { "forward": BinaryStream([b'he', b'll', b'o']), "back": b'hello' },
"list1": { "forward": SequenceStream([1, 2, 3, 4]), "back": (1, 2, 3, 4) },
"list2": { "forward": SequenceStream([ StringStream([b'abc']), StringStream([b'def']) ]),
"back": (u"abc", u"def") },
"list3": { "forward": SequenceStream([[u"a", 1], [u"b", 2], [u"c", 3]]),
"back": ((u"a", 1), (u"b", 2), (u"c", 3)) },
"record2": { "value": _R('observe', _R('speak', _R('discard'), _R('capture', _R('discard')))) },
"string0a": { "forward": StringStream([]), "back": u'' },
"string1": { "forward": StringStream([b'he', b'll', b'o']), "back": u'hello' },
"string2": { "forward": StringStream([b'he', b'llo']), "back": u'hello' },
"symbol1": { "forward": SymbolStream([b'he', b'll', b'o']), "back": Symbol('hello') },
}
def get_expected_values(tName, textForm):
@ -237,7 +211,6 @@ def install_test(d, variant, tName, binaryForm, annotatedTextForm):
add_method(d, tName, test_back_ann)
if variant not in ['decode', 'nondeterministic']:
add_method(d, tName, test_encode)
if variant not in ['decode', 'nondeterministic', 'streaming']:
add_method(d, tName, test_encode_ann)
def install_exn_test(d, tName, bs, check_proc):
@ -264,8 +237,6 @@ class CommonTestSuite(unittest.TestCase):
t = t0.peel()
if t.key == Symbol('Test'):
install_test(locals(), 'normal', tName, t[0].strip(), t[1])
elif t.key == Symbol('StreamingTest'):
install_test(locals(), 'streaming', tName, t[0].strip(), t[1])
elif t.key == Symbol('NondeterministicTest'):
install_test(locals(), 'nondeterministic', tName, t[0].strip(), t[1])
elif t.key == Symbol('DecodeTest'):

View File

@ -5,7 +5,7 @@ except ImportError:
setup(
name="preserves",
version="0.3.0",
version="0.4.0",
author="Tony Garnock-Jones",
author_email="tonyg@leastfixedpoint.com",
license="Apache Software License",