Blue python

This commit is contained in:
Tony Garnock-Jones 2022-06-19 13:57:17 +02:00
parent 39e7d7da34
commit 912ad34ab7
10 changed files with 344 additions and 302 deletions

View File

@ -1,12 +1,14 @@
test: update-test-data test: update-data
python3 -m unittest discover -s tests python3 -m unittest discover -s tests
coverage: update-test-data coverage: update-data
python3-coverage run --branch -m unittest discover -s tests python3-coverage run --branch -m unittest discover -s tests
python3-coverage html python3-coverage html
update-test-data: update-data:
rsync ../../tests/samples.bin ../../tests/samples.pr tests rsync ../../tests/samples.bin ../../tests/samples.pr tests
rsync ../../path/path.bin preserves/path.prb
rsync ../../schema/schema.bin preserves/schema.prb
tag: tag:
git tag python-preserves@`python3 setup.py --version` git tag python-preserves@`python3 setup.py --version`
@ -23,5 +25,5 @@ clean:
publish: clean build publish: clean build
twine upload dist/* twine upload dist/*
build: build: update-data
python3 setup.py sdist bdist_wheel python3 setup.py sdist bdist_wheel

View File

@ -10,7 +10,7 @@ from .text import Parser, Formatter, parse, parse_with_annotations, stringify
from .merge import merge from .merge import merge
from . import fold, compare from . import fold, compare, iolist
loads = parse loads = parse
dumps = stringify dumps = stringify

View File

@ -4,132 +4,135 @@ import struct
from .values import * from .values import *
from .error import * from .error import *
from .compat import basestring_, ord_ from .compat import basestring_, ord_
from . import iolist
class BinaryCodec(object): pass class BinaryCodec(object): pass
class Decoder(BinaryCodec): class Decoder(BinaryCodec):
def __init__(self, packet=b'', include_annotations=False, decode_embedded=lambda x: x): def __init__(self, *, include_annotations=False, decode_embedded=lambda x: x):
super(Decoder, self).__init__()
self.packet = packet
self.index = 0
self.include_annotations = include_annotations self.include_annotations = include_annotations
self.decode_embedded = decode_embedded self.decode_embedded = decode_embedded
def extend(self, data): def next(self, packet):
self.packet = self.packet[self.index:] + data if not packet: raise ShortPacket('Short packet')
self.index = 0 if not isinstance(packet, memoryview): packet = memoryview(packet)
tag = packet[0]
packet = packet[1:]
if tag == 0xA0: return self.wrap(False)
if tag == 0xA1: return self.wrap(True)
if tag == 0xA2:
if len(packet) == 4: return self.wrap(Float(struct.unpack('>f', packet)[0]))
if len(packet) == 8: return self.wrap(struct.unpack('>d', packet)[0])
raise DecodeError('Unsupported floating-point size ' + str(len(packet)))
if tag == 0xA3: return self.wrap(decode_int(packet))
if tag == 0xA4: return self.wrap(bytes(packet[:-1]).decode('utf-8'))
if tag == 0xA5: return self.wrap(bytes(packet))
if tag == 0xA6: return self.wrap(Symbol(bytes(packet).decode('utf-8')))
if tag == 0xA7:
vs = self.nextvalues(packet)
if not vs: raise DecodeError('Too few elements in encoded record')
return self.wrap(Record(vs[0], vs[1:]))
if tag == 0xA8: return self.wrap(tuple(self.nextvalues(packet)))
if tag == 0xA9: return self.wrap(frozenset(self.nextvalues(packet)))
if tag == 0xAA: return self.wrap(ImmutableDict.from_kvs(self.nextvalues(packet)))
if tag == 0xAB:
if self.decode_embedded is None:
raise DecodeError('No decode_embedded function supplied')
return self.wrap(Embedded(self.decode_embedded(self.next(packet))))
if tag == 0xBF:
if self.include_annotations:
vs = self.nextvalues(packet)
if not vs: raise DecodeError('No elements in annotation')
vs[0].annotations.extend(vs[1:])
return vs[0]
else:
e = self.nextitem(packet)
if e is None: raise DecodeError('No elements in annotation')
return e[0]
raise DecodeError('Invalid tag: ' + hex(tag))
def nextbyte(self): def nextvalues(self, packet):
if self.index >= len(self.packet): vs = []
raise ShortPacket('Short packet') while True:
self.index = self.index + 1 e = self.nextitem(packet)
return ord_(self.packet[self.index - 1]) if e is None: return vs
vs.append(e[0])
packet = e[1]
def nextbytes(self, n): def nextitem(self, packet):
start = self.index if not packet: return None
end = start + n (count, i) = decode_varint(packet)
if end > len(self.packet): item = packet[i:i+count]
raise ShortPacket('Short packet') packet = packet[i+count:]
self.index = end return (self.next(item), packet)
return self.packet[start : end]
def varint(self):
v = self.nextbyte()
if v < 128:
return v
else:
return self.varint() * 128 + (v - 128)
def peekend(self):
matched = (self.nextbyte() == 0x84)
if not matched:
self.index = self.index - 1
return matched
def nextvalues(self):
result = []
while not self.peekend():
result.append(self.next())
return result
def nextint(self, n):
if n == 0: return 0
acc = self.nextbyte()
if acc & 0x80: acc = acc - 256
for _i in range(n - 1):
acc = (acc << 8) | self.nextbyte()
return acc
def wrap(self, v): def wrap(self, v):
return Annotated(v) if self.include_annotations else v return Annotated(v) if self.include_annotations else v
def unshift_annotation(self, a, v): def try_next(self, packet):
if self.include_annotations:
v.annotations.insert(0, a)
return v
def next(self):
tag = self.nextbyte()
if tag == 0x80: return self.wrap(False)
if tag == 0x81: return self.wrap(True)
if tag == 0x82: return self.wrap(Float(struct.unpack('>f', self.nextbytes(4))[0]))
if tag == 0x83: return self.wrap(struct.unpack('>d', self.nextbytes(8))[0])
if tag == 0x84: raise DecodeError('Unexpected end-of-stream marker')
if tag == 0x85:
a = self.next()
v = self.next()
return self.unshift_annotation(a, v)
if tag == 0x86:
if self.decode_embedded is None:
raise DecodeError('No decode_embedded function supplied')
return self.wrap(Embedded(self.decode_embedded(self.next())))
if tag >= 0x90 and tag <= 0x9f: return self.wrap(tag - (0xa0 if tag > 0x9c else 0x90))
if tag >= 0xa0 and tag <= 0xaf: return self.wrap(self.nextint(tag - 0xa0 + 1))
if tag == 0xb0: return self.wrap(self.nextint(self.varint()))
if tag == 0xb1: return self.wrap(self.nextbytes(self.varint()).decode('utf-8'))
if tag == 0xb2: return self.wrap(self.nextbytes(self.varint()))
if tag == 0xb3: return self.wrap(Symbol(self.nextbytes(self.varint()).decode('utf-8')))
if tag == 0xb4:
vs = self.nextvalues()
if not vs: raise DecodeError('Too few elements in encoded record')
return self.wrap(Record(vs[0], vs[1:]))
if tag == 0xb5: return self.wrap(tuple(self.nextvalues()))
if tag == 0xb6: return self.wrap(frozenset(self.nextvalues()))
if tag == 0xb7: return self.wrap(ImmutableDict.from_kvs(self.nextvalues()))
raise DecodeError('Invalid tag: ' + hex(tag))
def try_next(self):
start = self.index
try: try:
return self.next() return self.next(packet)
except ShortPacket: except ShortPacket:
self.index = start
return None return None
def decode_varint(packet):
    """Decode a variable-length unsigned integer from the front of `packet`.

    Bytes are read most-significant group first, seven payload bits per
    byte. A byte with its high bit set is the final byte of the number.

    Returns a pair `(value, consumed)`, where `consumed` is how many bytes
    of `packet` the encoded number occupied.

    Raises ShortPacket if `packet` ends before a terminating byte is seen.
    """
    value = 0
    consumed = 0
    for octet in packet:
        consumed += 1
        # Shift in the low seven bits; the top bit is only a terminator flag.
        value = (value << 7) | (octet & 0x7F)
        if octet >= 0x80:
            return (value, consumed)
    raise ShortPacket('Short packet (incomplete length)')
def decode_int(packet):
    """Decode a big-endian two's-complement signed integer.

    `packet` is a bytes-like object (bytes or a memoryview over bytes)
    holding the integer's bytes, most significant first. An empty packet
    decodes to 0.

    Uses the standard-library conversion rather than a byte-at-a-time
    shift/or loop, which is quadratic for very large integers.
    """
    return int.from_bytes(packet, byteorder='big', signed=True)
class StreamDecoder(object):
def __init__(self, initial_packet, decoder = None):
self.decoder = decoder or Decoder()
if not initial_packet:
raise DecodeError('Empty initial packet in StreamDecoder')
if initial_packet[0] != 0xA8:
raise DecodeError('Initial stream packet is not a Sequence')
self.buffer = memoryview(initial_packet[1:])
def extend(self, data):
self.buffer = memoryview(bytes(self.buffer) + data)
def __iter__(self): def __iter__(self):
return self return self
def __next__(self): def __next__(self):
v = self.try_next() try:
if v is None: e = self.decoder.next(self.buffer)
if e is None: raise StopIteration
self.buffer = e[1]
return e[0]
except ShortPacket:
raise StopIteration raise StopIteration
return v
def decode(bs, **kwargs): def decode(bs, **kwargs):
return Decoder(packet=bs, **kwargs).next() return Decoder(**kwargs).next(bs)
def decode_with_annotations(bs, **kwargs): def decode_with_annotations(bs, **kwargs):
return Decoder(packet=bs, include_annotations=True, **kwargs).next() return Decoder(include_annotations=True, **kwargs).next(bs)
class Encoder(BinaryCodec): class Encoder(BinaryCodec):
def __init__(self, encode_embedded=lambda x: x, canonicalize=False): def __init__(self, *, encode_embedded=lambda x: x, canonicalize=False, include_annotations=None):
super(Encoder, self).__init__() super(Encoder, self).__init__()
self.buffer = bytearray() self.buffer = None
self._encode_embedded = encode_embedded self._encode_embedded = encode_embedded
self._canonicalize = canonicalize self._canonicalize = canonicalize
if include_annotations is None:
self.include_annotations = not canonicalize
else:
self.include_annotations = include_annotations
def reset(self): def reset(self):
self.buffer = bytearray() self.buffer = None
def encode_embedded(self, v): def encode_embedded(self, v):
if self._encode_embedded is None: if self._encode_embedded is None:
@ -137,118 +140,79 @@ class Encoder(BinaryCodec):
return self._encode_embedded(v) return self._encode_embedded(v)
def contents(self): def contents(self):
return bytes(self.buffer) return iolist.bytes(self.buffer)
def varint(self, v): def lengthprefixed(self, encoded):
if v < 128: encoded = iolist.counted(encoded)
self.buffer.append(v) return [encode_varint(iolist.len(encoded)), encoded]
else:
self.buffer.append((v % 128) + 128)
self.varint(v // 128)
def encodeint(self, v): def encodeditem(self, v):
bitcount = (~v if v < 0 else v).bit_length() + 1 return self.lengthprefixed(self.encoded_iolist(v))
bytecount = (bitcount + 7) // 8
if bytecount <= 16:
self.buffer.append(0xa0 + bytecount - 1)
else:
self.buffer.append(0xb0)
self.varint(bytecount)
def enc(n,x):
if n > 0:
enc(n-1, x >> 8)
self.buffer.append(x & 255)
enc(bytecount, v)
def encodevalues(self, tag, items): def encodedvalues(self, vs):
self.buffer.append(0xb0 + tag) return [self.encodeditem(v) for v in vs]
for i in items: self.append(i)
self.buffer.append(0x84)
def encodebytes(self, tag, bs): def encoded(self, v):
self.buffer.append(0xb0 + tag) return iolist.bytes(self.encoded_iolist(v))
self.varint(len(bs))
self.buffer.extend(bs)
def encodeset(self, v): def encoded_iolist(self, v):
if not self._canonicalize:
self.encodevalues(6, v)
else:
c = Canonicalizer(self._encode_embedded)
for i in v: c.entry([i])
c.emit_entries(self, 6)
def encodedict(self, v):
if not self._canonicalize:
self.encodevalues(7, list(dict_kvs(v)))
else:
c = Canonicalizer(self._encode_embedded)
for (kk, vv) in v.items(): c.entry([kk, vv])
c.emit_entries(self, 7)
def append(self, v):
v = preserve(v) v = preserve(v)
if hasattr(v, '__preserve_write_binary__'): if hasattr(v, '__preserve_encoded__'): return v.__preserve_encoded__(self)
v.__preserve_write_binary__(self) if v is False: return 0xA0
elif v is False: if v is True: return 0xA1
self.buffer.append(0x80) if isinstance(v, float): return [0xA2, struct.pack('>d', v)]
elif v is True: if isinstance(v, numbers.Number): return [0xA3, encode_int(v)]
self.buffer.append(0x81) if isinstance(v, bytes): return [0xA5, v]
elif isinstance(v, float): if isinstance(v, basestring_): return [0xA4, v.encode('utf-8'), 0]
self.buffer.append(0x83) if isinstance(v, list): return [0xA8, self.encodedvalues(v)]
self.buffer.extend(struct.pack('>d', v)) if isinstance(v, tuple): return [0xA8, self.encodedvalues(v)]
elif isinstance(v, numbers.Number): if isinstance(v, set) or isinstance(v, frozenset):
if v >= -3 and v <= 12: if self._canonicalize:
self.buffer.append(0x90 + (v if v >= 0 else v + 16)) return [0xA9, [self.encodeditem(i)
for (_c, i) in sorted((canonicalize(i), i) for i in v)]]
else: else:
self.encodeint(v) return [0xA9, self.encodedvalues(v)]
elif isinstance(v, bytes): if isinstance(v, dict):
self.encodebytes(2, v) if self._canonicalize:
elif isinstance(v, basestring_): return [0xAA, [[self.encodeditem(k), self.encodeditem(v)]
self.encodebytes(1, v.encode('utf-8')) for (_c, k, v) in sorted((canonicalize(k), k, v)
elif isinstance(v, list): for (k, v) in v.items())]]
self.encodevalues(5, v)
elif isinstance(v, tuple):
self.encodevalues(5, v)
elif isinstance(v, set):
self.encodeset(v)
elif isinstance(v, frozenset):
self.encodeset(v)
elif isinstance(v, dict):
self.encodedict(v)
else:
try:
i = iter(v)
except TypeError:
i = None
if i is None:
self.cannot_encode(v)
else: else:
self.encodevalues(5, i) return [0xAA, [[self.encodeditem(k), self.encodeditem(v)] for (k, v) in v.items()]]
try:
i = iter(v)
except TypeError:
i = None
if i is not None:
return [0xA8, self.encodedvalues(i)]
self.cannot_encode(v)
def cannot_encode(self, v): def cannot_encode(self, v):
raise TypeError('Cannot preserves-encode: ' + repr(v)) raise TypeError('Cannot preserves-encode: ' + repr(v))
class Canonicalizer: def encode_varint(n):
def __init__(self, encode_embedded): L = (n & 127) | 128
self.encoder = Encoder(encode_embedded, canonicalize=True) n = n >> 7
self.entries = [] while n > 0:
L = [n & 127, L]
n = n >> 7
return L
def entry(self, pieces): def encode_int(v):
for piece in pieces: self.encoder.append(piece) if v == 0: return None
entry = self.encoder.contents() if v == -1: return 255
self.encoder.reset()
self.entries.append(entry)
def emit_entries(self, outer_encoder, tag): bitcount = (~v if v < 0 else v).bit_length() + 1
outer_encoder.buffer.append(0xb0 + tag) bytecount = (bitcount + 7) // 8
for e in sorted(self.entries): outer_encoder.buffer.extend(e)
outer_encoder.buffer.append(0x84) D = None
for _i in range(bytecount):
D = [v & 255, D]
v = v >> 8
return D
def encode(v, **kwargs): def encode(v, **kwargs):
e = Encoder(**kwargs) return Encoder(**kwargs).encoded(v)
e.append(v)
return e.contents()
def canonicalize(v, **kwargs): def canonicalize(v, **kwargs):
return encode(v, canonicalize=True, **kwargs) return encode(v, canonicalize=True, **kwargs)

View File

@ -0,0 +1,72 @@
# An iolist is one of
# - None
# - a list of iolists
# - a CountedIOList
# - a bytes
# - a number i, 0 <= i < 256
class CountedIOList:
    """An iolist paired with its length, measured once at construction.

    `value` holds the wrapped iolist and `length` the result of calling
    this module's `len` on it, so later length queries need not walk the
    structure again.
    """

    def __init__(self, i):
        # `len` resolves to the module-level `len` at call time.
        self.length = len(i)
        self.value = i
def counted(i):
    """Return `i` as a CountedIOList, wrapping it unless it already is one."""
    return i if isinstance(i, CountedIOList) else CountedIOList(i)
def withbyte(i, b):
    """Return an iolist consisting of `i` followed by `b`.

    A list `i` is extended in place and returned (the caller's list is
    mutated). `None` yields `b` itself. Anything else is paired with `b`
    in a fresh two-element list.
    """
    if isinstance(i, list):
        i.append(b)
        return i
    return b if i is None else [i, b]
# Keep a reference to the builtin before this module defines its own `len`.
_len = len


def join(*iolists):
    """Combine any number of iolists into a single iolist.

    No arguments gives `None` (the empty iolist), a single argument is
    returned as-is, and several arguments are collected into a new list.
    """
    if not iolists:
        return None
    first, *rest = iolists
    return list(iolists) if rest else first
def len(i):
    """Return the number of bytes the iolist `i` flattens to.

    `None` counts as 0, an integer as one byte, a bytes object as its own
    length, a list as the sum of its elements, and a CountedIOList as its
    recorded `length`.

    The walk uses an explicit stack instead of recursion, so lists nested
    thousands of levels deep (one level per byte, as produced when a very
    large integer is encoded) do not hit the interpreter's recursion limit.

    Raises ValueError if any element is not a valid iolist.
    """
    total = 0
    pending = [i]
    while pending:
        node = pending.pop()
        if node is None:
            continue
        if isinstance(node, int):
            total += 1
        elif isinstance(node, list):
            # Reversed so elements are visited left to right; this keeps
            # the first invalid element reported the same as before.
            pending.extend(reversed(node))
        elif isinstance(node, _bytes):
            total += _len(node)
        elif isinstance(node, CountedIOList):
            total += node.length
        else:
            raise ValueError('invalid iolist: ' + repr(node) + ' ' + repr(type(node)))
    return total
# Keep a reference to the builtin before this module defines its own `bytes`.
_bytes = bytes


def bytes(i):
    """Flatten the iolist `i` into a single builtin bytes object.

    Elements are emitted in left-to-right order: `None` contributes
    nothing, an integer contributes one byte, a bytes object contributes
    its contents, and lists and CountedIOLists contribute whatever they
    contain.

    The walk is iterative and appends to a growing buffer. It therefore
    neither hits the recursion limit on deeply nested lists nor relies on
    a previously computed length that may be stale if a counted list was
    mutated afterwards.

    Raises ValueError if an element is not a valid iolist, or if an
    integer element is outside the range 0..255.
    """
    out = bytearray()
    pending = [i]
    while pending:
        node = pending.pop()
        if node is None:
            continue
        if isinstance(node, int):
            out.append(node)
        elif isinstance(node, list):
            # Push in reverse so the leftmost element is popped first.
            pending.extend(reversed(node))
        elif isinstance(node, _bytes):
            out += node
        elif isinstance(node, CountedIOList):
            pending.append(node.value)
        else:
            raise ValueError('invalid iolist: ' + repr(node) + ' ' + repr(type(node)))
    return _bytes(out)

View File

@ -459,7 +459,7 @@ class Compiler:
def load(self, filename): def load(self, filename):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)
with open(filename, 'rb') as f: with open(filename, 'rb') as f:
x = Decoder(f.read()).next() x = Decoder().next(f.read())
if x.key == SCHEMA: if x.key == SCHEMA:
self.load_schema((Symbol(filename.stem),), x) self.load_schema((Symbol(filename.stem),), x)
elif x.key == BUNDLE: elif x.key == BUNDLE:
@ -500,7 +500,7 @@ meta = load_schema_file(__metaschema_filename).schema
if __name__ == '__main__': if __name__ == '__main__':
with open(__metaschema_filename, 'rb') as f: with open(__metaschema_filename, 'rb') as f:
x = Decoder(f.read()).next() x = Decoder().next(f.read())
print(meta.Schema.decode(x)) print(meta.Schema.decode(x))
print(preserve(meta.Schema.decode(x))) print(preserve(meta.Schema.decode(x)))
assert preserve(meta.Schema.decode(x)) == x assert preserve(meta.Schema.decode(x)) == x
@ -516,7 +516,7 @@ if __name__ == '__main__':
path_bin_filename = pathlib.Path(__file__).parent / 'path.prb' path_bin_filename = pathlib.Path(__file__).parent / 'path.prb'
path = load_schema_file(path_bin_filename).path path = load_schema_file(path_bin_filename).path
with open(path_bin_filename, 'rb') as f: with open(path_bin_filename, 'rb') as f:
x = Decoder(f.read()).next() x = Decoder().next(f.read())
print(meta.Schema.decode(x)) print(meta.Schema.decode(x))
assert meta.Schema.decode(x) == meta.Schema.decode(x) assert meta.Schema.decode(x) == meta.Schema.decode(x)
assert preserve(meta.Schema.decode(x)) == x assert preserve(meta.Schema.decode(x)) == x

View File

@ -265,7 +265,7 @@ class Parser(TextCodec):
bs_val = bs_val.item bs_val = bs_val.item
if not isinstance(bs_val, bytes): if not isinstance(bs_val, bytes):
raise DecodeError('ByteString must follow #=') raise DecodeError('ByteString must follow #=')
return self.wrap(Decoder(bs_val, include_annotations = self.include_annotations).next()) return self.wrap(Decoder(include_annotations=self.include_annotations).next(bs_val))
if c == '!': if c == '!':
if self.parse_embedded is None: if self.parse_embedded is None:
raise DecodeError('No parse_embedded function supplied') raise DecodeError('No parse_embedded function supplied')

View File

@ -27,9 +27,8 @@ class Float(object):
def __repr__(self): def __repr__(self):
return 'Float(' + repr(self.value) + ')' return 'Float(' + repr(self.value) + ')'
def __preserve_write_binary__(self, encoder): def __preserve_encoded__(self, encoder):
encoder.buffer.append(0x82) return [0xA2, struct.pack('>f', self.value)]
encoder.buffer.extend(struct.pack('>f', self.value))
def __preserve_write_text__(self, formatter): def __preserve_write_text__(self, formatter):
formatter.chunks.append(repr(self.value) + 'f') formatter.chunks.append(repr(self.value) + 'f')
@ -66,11 +65,8 @@ class Symbol(object):
def __repr__(self): def __repr__(self):
return '#' + self.name return '#' + self.name
def __preserve_write_binary__(self, encoder): def __preserve_encoded__(self, encoder):
bs = self.name.encode('utf-8') return [0xA6, self.name.encode('utf-8')]
encoder.buffer.append(0xb3)
encoder.varint(len(bs))
encoder.buffer.extend(bs)
def __preserve_write_text__(self, formatter): def __preserve_write_text__(self, formatter):
if RAW_SYMBOL_RE.match(self.name): if RAW_SYMBOL_RE.match(self.name):
@ -103,12 +99,10 @@ class Record(object):
def __repr__(self): def __repr__(self):
return str(self.key) + '(' + ', '.join((repr(f) for f in self.fields)) + ')' return str(self.key) + '(' + ', '.join((repr(f) for f in self.fields)) + ')'
def __preserve_write_binary__(self, encoder): def __preserve_encoded__(self, encoder):
encoder.buffer.append(0xb4) return [0xA7,
encoder.append(self.key) encoder.encodeditem(self.key),
for f in self.fields: encoder.encodedvalues(self.fields)]
encoder.append(f)
encoder.buffer.append(0x84)
def __preserve_write_text__(self, formatter): def __preserve_write_text__(self, formatter):
formatter.chunks.append('<') formatter.chunks.append('<')
@ -226,11 +220,13 @@ class Annotated(object):
self.annotations = [] self.annotations = []
self.item = item self.item = item
def __preserve_write_binary__(self, encoder): def __preserve_encoded__(self, encoder):
for a in self.annotations: if self.annotations and encoder.include_annotations:
encoder.buffer.append(0x85) return [0xBF,
encoder.append(a) encoder.encodeditem(self.item),
encoder.append(self.item) encoder.encodedvalues(self.annotations)]
else:
return encoder.encoded_iolist(self.item)
def __preserve_write_text__(self, formatter): def __preserve_write_text__(self, formatter):
for a in self.annotations: for a in self.annotations:
@ -314,9 +310,8 @@ class Embedded:
def __repr__(self): def __repr__(self):
return '#!%r' % (self.embeddedValue,) return '#!%r' % (self.embeddedValue,)
def __preserve_write_binary__(self, encoder): def __preserve_encoded__(self, encoder):
encoder.buffer.append(0x86) return [0xAB, encoder.encoded_iolist(encoder.encode_embedded(self.embeddedValue))]
encoder.append(encoder.encode_embedded(self.embeddedValue))
def __preserve_write_text__(self, formatter): def __preserve_write_text__(self, formatter):
formatter.chunks.append('#!') formatter.chunks.append('#!')

View File

@ -36,9 +36,7 @@ def _buf(*args):
return result return result
def _varint(v): def _varint(v):
e = Encoder() return iolist.bytes(binary.encode_varint(v))
e.varint(v)
return e.contents()
def _d(bs): def _d(bs):
return decode(bs) return decode(bs)
@ -60,76 +58,81 @@ class BinaryCodecTests(unittest.TestCase):
self.assertEqual(actual, expected, '%s != %s' % (_hex(actual), _hex(expected))) self.assertEqual(actual, expected, '%s != %s' % (_hex(actual), _hex(expected)))
def test_decode_varint(self): def test_decode_varint(self):
with self.assertRaises(DecodeError): with self.assertRaises(ShortPacket):
Decoder(_buf()).varint() binary.decode_varint(_buf())
self.assertEqual(Decoder(_buf(0)).varint(), 0) def dv(bs):
self.assertEqual(Decoder(_buf(10)).varint(), 10) (n, s) = binary.decode_varint(bs)
self.assertEqual(Decoder(_buf(100)).varint(), 100) self.assertEqual(s, len(bs))
self.assertEqual(Decoder(_buf(200, 1)).varint(), 200) return n
self.assertEqual(Decoder(_buf(0b10101100, 0b00000010)).varint(), 300) self.assertEqual(dv(_buf(128)), 0)
self.assertEqual(Decoder(_buf(128, 148, 235, 220, 3)).varint(), 1000000000) self.assertEqual(dv(_buf(138)), 10)
self.assertEqual(dv(_buf(228)), 100)
self.assertEqual(dv(_buf(1, 200)), 200)
self.assertEqual(dv(_buf(0b00000010, 0b10101100)), 300)
self.assertEqual(dv(_buf(3, 92, 107, 20, 128)), 1000000000)
def test_encode_varint(self): def test_encode_varint(self):
self.assertEqual(_varint(0), _buf(0)) self.assertEqual(_varint(0), _buf(128))
self.assertEqual(_varint(10), _buf(10)) self.assertEqual(_varint(10), _buf(138))
self.assertEqual(_varint(100), _buf(100)) self.assertEqual(_varint(100), _buf(228))
self.assertEqual(_varint(200), _buf(200, 1)) self.assertEqual(_varint(200), _buf(1, 200))
self.assertEqual(_varint(300), _buf(0b10101100, 0b00000010)) self.assertEqual(_varint(300), _buf(0b00000010, 0b10101100))
self.assertEqual(_varint(1000000000), _buf(128, 148, 235, 220, 3)) self.assertEqual(_varint(1000000000), _buf(3, 92, 107, 20, 128))
def test_simple_seq(self): def test_simple_seq(self):
self._roundtrip([1,2,3,4], _buf(0xb5, 0x91, 0x92, 0x93, 0x94, 0x84), back=(1,2,3,4)) b1234 = _buf(0xa8, 0x82, 0xa3, 0x01, 0x82, 0xa3, 0x02, 0x82, 0xa3, 0x03, 0x82, 0xa3, 0x04)
self._roundtrip(iter([1,2,3,4]), self._roundtrip([1,2,3,4], b1234, back=(1,2,3,4))
_buf(0xb5, 0x91, 0x92, 0x93, 0x94, 0x84), self._roundtrip(iter([1,2,3,4]), b1234, back=(1,2,3,4), nondeterministic=True)
back=(1,2,3,4), self._roundtrip((-2,-1,0,1), _buf(0xa8,
nondeterministic=True) 0x82, 0xa3, 0xfe,
self._roundtrip((-2,-1,0,1), _buf(0xb5, 0x9E, 0x9F, 0x90, 0x91, 0x84)) 0x82, 0xa3, 0xff,
0x81, 0xa3,
0x82, 0xa3, 0x01))
def test_str(self): def test_str(self):
self._roundtrip(u'hello', _buf(0xb1, 0x05, 'hello')) self._roundtrip(u'hello', _buf(0xa4, 'hello', 0))
def test_mixed1(self): def test_mixed1(self):
self._roundtrip((u'hello', Symbol(u'there'), b'world', (), set(), True, False), self._roundtrip((u'hello', Symbol(u'there'), b'world', (), set(), True, False),
_buf(0xb5, _buf(0xa8,
0xb1, 0x05, 'hello', 0x87, 0xa4, 'hello', 0,
0xb3, 0x05, 'there', 0x86, 0xa6, 'there',
0xb2, 0x05, 'world', 0x86, 0xa5, 'world',
0xb5, 0x84, 0x81, 0xa8,
0xb6, 0x84, 0x81, 0xa9,
0x81, 0x81, 0xa1,
0x80, 0x81, 0xa0))
0x84))
def test_signedinteger(self): def test_signedinteger(self):
self._roundtrip(-257, _buf(0xa1, 0xFE, 0xFF)) self._roundtrip(-257, _buf(0xa3, 0xFE, 0xFF))
self._roundtrip(-256, _buf(0xa1, 0xFF, 0x00)) self._roundtrip(-256, _buf(0xa3, 0xFF, 0x00))
self._roundtrip(-255, _buf(0xa1, 0xFF, 0x01)) self._roundtrip(-255, _buf(0xa3, 0xFF, 0x01))
self._roundtrip(-254, _buf(0xa1, 0xFF, 0x02)) self._roundtrip(-254, _buf(0xa3, 0xFF, 0x02))
self._roundtrip(-129, _buf(0xa1, 0xFF, 0x7F)) self._roundtrip(-129, _buf(0xa3, 0xFF, 0x7F))
self._roundtrip(-128, _buf(0xa0, 0x80)) self._roundtrip(-128, _buf(0xa3, 0x80))
self._roundtrip(-127, _buf(0xa0, 0x81)) self._roundtrip(-127, _buf(0xa3, 0x81))
self._roundtrip(-4, _buf(0xa0, 0xFC)) self._roundtrip(-4, _buf(0xa3, 0xFC))
self._roundtrip(-3, _buf(0x9D)) self._roundtrip(-3, _buf(0xa3, 0xFD))
self._roundtrip(-2, _buf(0x9E)) self._roundtrip(-2, _buf(0xa3, 0xFE))
self._roundtrip(-1, _buf(0x9F)) self._roundtrip(-1, _buf(0xa3, 0xFF))
self._roundtrip(0, _buf(0x90)) self._roundtrip(0, _buf(0xa3))
self._roundtrip(1, _buf(0x91)) self._roundtrip(1, _buf(0xa3, 0x01))
self._roundtrip(12, _buf(0x9C)) self._roundtrip(12, _buf(0xa3, 0x0C))
self._roundtrip(13, _buf(0xa0, 0x0D)) self._roundtrip(13, _buf(0xa3, 0x0D))
self._roundtrip(127, _buf(0xa0, 0x7F)) self._roundtrip(127, _buf(0xa3, 0x7F))
self._roundtrip(128, _buf(0xa1, 0x00, 0x80)) self._roundtrip(128, _buf(0xa3, 0x00, 0x80))
self._roundtrip(255, _buf(0xa1, 0x00, 0xFF)) self._roundtrip(255, _buf(0xa3, 0x00, 0xFF))
self._roundtrip(256, _buf(0xa1, 0x01, 0x00)) self._roundtrip(256, _buf(0xa3, 0x01, 0x00))
self._roundtrip(32767, _buf(0xa1, 0x7F, 0xFF)) self._roundtrip(32767, _buf(0xa3, 0x7F, 0xFF))
self._roundtrip(32768, _buf(0xa2, 0x00, 0x80, 0x00)) self._roundtrip(32768, _buf(0xa3, 0x00, 0x80, 0x00))
self._roundtrip(65535, _buf(0xa2, 0x00, 0xFF, 0xFF)) self._roundtrip(65535, _buf(0xa3, 0x00, 0xFF, 0xFF))
self._roundtrip(65536, _buf(0xa2, 0x01, 0x00, 0x00)) self._roundtrip(65536, _buf(0xa3, 0x01, 0x00, 0x00))
self._roundtrip(131072, _buf(0xa2, 0x02, 0x00, 0x00)) self._roundtrip(131072, _buf(0xa3, 0x02, 0x00, 0x00))
def test_floats(self): def test_floats(self):
self._roundtrip(Float(1.0), _buf(0x82, 0x3f, 0x80, 0, 0)) self._roundtrip(Float(1.0), _buf(0xa2, 0x3f, 0x80, 0, 0))
self._roundtrip(1.0, _buf(0x83, 0x3f, 0xf0, 0, 0, 0, 0, 0, 0)) self._roundtrip(1.0, _buf(0xa2, 0x3f, 0xf0, 0, 0, 0, 0, 0, 0))
self._roundtrip(-1.202e300, _buf(0x83, 0xfe, 0x3c, 0xb7, 0xb7, 0x59, 0xbf, 0x04, 0x26)) self._roundtrip(-1.202e300, _buf(0xa2, 0xfe, 0x3c, 0xb7, 0xb7, 0x59, 0xbf, 0x04, 0x26))
def test_dict(self): def test_dict(self):
self._roundtrip({ Symbol(u'a'): 1, self._roundtrip({ Symbol(u'a'): 1,
@ -137,18 +140,17 @@ class BinaryCodecTests(unittest.TestCase):
(1, 2, 3): b'c', (1, 2, 3): b'c',
ImmutableDict({ Symbol(u'first-name'): u'Elizabeth', }): ImmutableDict({ Symbol(u'first-name'): u'Elizabeth', }):
{ Symbol(u'surname'): u'Blackwell' } }, { Symbol(u'surname'): u'Blackwell' } },
_buf(0xB7, _buf(0xaa,
0xb3, 0x01, "a", 0x91, 0x82, 0xa6, "a", 0x82, 0xa3, 0x01,
0xb1, 0x01, "b", 0x81, 0x83, 0xa4, "b", 0, 0x81, 0xa1,
0xb5, 0x91, 0x92, 0x93, 0x84, 0xb2, 0x01, "c", 0x8a, 0xa8, 0x82,0xa3,0x01, 0x82,0xa3,0x02, 0x82,0xa3,0x03, 0x82, 0xa5, "c",
0xB7, 0xb3, 0x0A, "first-name", 0xb1, 0x09, "Elizabeth", 0x84, 0x99, 0xaa, 0x8b, 0xa6, "first-name", 0x8b, 0xa4, "Elizabeth", 0,
0xB7, 0xb3, 0x07, "surname", 0xb1, 0x09, "Blackwell", 0x84, 0x96, 0xaa, 0x88, 0xa6, "surname", 0x8b, 0xa4, "Blackwell", 0),
0x84),
nondeterministic = True) nondeterministic = True)
def test_iterator_stream(self): def test_iterator_stream(self):
d = {u'a': 1, u'b': 2, u'c': 3} d = {u'a': 1, u'b': 2, u'c': 3}
r = r'b5(b5b1016.9.84){3}84' r = r'a8(88a883a46.0082a30.){3}'
if hasattr(d, 'iteritems'): if hasattr(d, 'iteritems'):
# python 2 # python 2
bs = _e(d.iteritems()) bs = _e(d.iteritems())
@ -160,10 +162,10 @@ class BinaryCodecTests(unittest.TestCase):
self.assertEqual(sorted(_d(bs)), [(u'a', 1), (u'b', 2), (u'c', 3)]) self.assertEqual(sorted(_d(bs)), [(u'a', 1), (u'b', 2), (u'c', 3)])
def test_long_sequence(self): def test_long_sequence(self):
self._roundtrip((False,) * 14, _buf(0xb5, b'\x80' * 14, 0x84)) self._roundtrip((False,) * 14, _buf(0xa8, b'\x81\xa0' * 14))
self._roundtrip((False,) * 15, _buf(0xb5, b'\x80' * 15, 0x84)) self._roundtrip((False,) * 15, _buf(0xa8, b'\x81\xa0' * 15))
self._roundtrip((False,) * 100, _buf(0xb5, b'\x80' * 100, 0x84)) self._roundtrip((False,) * 100, _buf(0xa8, b'\x81\xa0' * 100))
self._roundtrip((False,) * 200, _buf(0xb5, b'\x80' * 200, 0x84)) self._roundtrip((False,) * 200, _buf(0xa8, b'\x81\xa0' * 200))
def test_embedded_id(self): def test_embedded_id(self):
class A: class A:
@ -173,12 +175,12 @@ class BinaryCodecTests(unittest.TestCase):
a2 = Embedded(A(1)) a2 = Embedded(A(1))
self.assertNotEqual(encode(a1, encode_embedded=id), encode(a2, encode_embedded=id)) self.assertNotEqual(encode(a1, encode_embedded=id), encode(a2, encode_embedded=id))
self.assertEqual(encode(a1, encode_embedded=id), encode(a1, encode_embedded=id)) self.assertEqual(encode(a1, encode_embedded=id), encode(a1, encode_embedded=id))
self.assertEqual(ord_(encode(a1, encode_embedded=id)[0]), 0x86) self.assertEqual(ord_(encode(a1, encode_embedded=id)[0]), 0xab)
self.assertEqual(ord_(encode(a2, encode_embedded=id)[0]), 0x86) self.assertEqual(ord_(encode(a2, encode_embedded=id)[0]), 0xab)
def test_decode_embedded_absent(self): def test_decode_embedded_absent(self):
with self.assertRaises(DecodeError): with self.assertRaises(DecodeError):
decode(b'\x86\xa0\xff', decode_embedded=None) decode(b'\xab\xa3\xff', decode_embedded=None)
def test_encode_embedded(self): def test_encode_embedded(self):
objects = [] objects = []
@ -186,18 +188,18 @@ class BinaryCodecTests(unittest.TestCase):
objects.append(p) objects.append(p)
return len(objects) - 1 return len(objects) - 1
self.assertEqual(encode([Embedded(object()), Embedded(object())], encode_embedded = enc), self.assertEqual(encode([Embedded(object()), Embedded(object())], encode_embedded = enc),
b'\xb5\x86\x90\x86\x91\x84') b'\xa8\x82\xab\xa3\x83\xab\xa3\x01')
def test_decode_embedded(self): def test_decode_embedded(self):
objects = [123, 234] objects = [123, 234]
def dec(v): def dec(v):
return objects[v] return objects[v]
self.assertEqual(decode(b'\xb5\x86\x90\x86\x91\x84', decode_embedded = dec), self.assertEqual(decode(b'\xa8\x82\xab\xa3\x83\xab\xa3\x01', decode_embedded = dec),
(Embedded(123), Embedded(234))) (Embedded(123), Embedded(234)))
def load_binary_samples(): def load_binary_samples():
with open(os.path.join(os.path.dirname(__file__), 'samples.bin'), 'rb') as f: with open(os.path.join(os.path.dirname(__file__), 'samples.bin'), 'rb') as f:
return Decoder(f.read(), include_annotations=True, decode_embedded=lambda x: x).next() return Decoder(include_annotations=True, decode_embedded=lambda x: x).next(f.read())
def load_text_samples(): def load_text_samples():
with open(os.path.join(os.path.dirname(__file__), 'samples.pr'), 'rt') as f: with open(os.path.join(os.path.dirname(__file__), 'samples.pr'), 'rt') as f:
@ -260,7 +262,14 @@ def install_test(d, variant, tName, binaryForm, annotatedTextForm):
def test_back(self): self.assertEqual(self.DS(binaryForm), back) def test_back(self): self.assertEqual(self.DS(binaryForm), back)
def test_back_ann(self): self.assertEqual(self.D(self.E(annotatedTextForm)), annotatedTextForm) def test_back_ann(self): self.assertEqual(self.D(self.E(annotatedTextForm)), annotatedTextForm)
def test_encode(self): self.assertEqual(self.E(forward), binaryForm) def test_encode(self): self.assertEqual(self.E(forward), binaryForm)
def test_encode_canonical(self): self.assertEqual(self.EC(annotatedTextForm), binaryForm) def test_encode_canonical_annotated(self):
a = self.ECA(annotatedTextForm)
b = binaryForm
if a != b:
print('\nval:', annotatedTextForm)
print('ECA:', a.hex())
print('bin:', b.hex())
self.assertEqual(self.ECA(annotatedTextForm), binaryForm)
def test_encode_ann(self): self.assertEqual(self.E(annotatedTextForm), binaryForm) def test_encode_ann(self): self.assertEqual(self.E(annotatedTextForm), binaryForm)
add_method(d, tName, test_match_expected) add_method(d, tName, test_match_expected)
add_method(d, tName, test_roundtrip) add_method(d, tName, test_roundtrip)
@ -270,7 +279,7 @@ def install_test(d, variant, tName, binaryForm, annotatedTextForm):
if variant in ['normal']: if variant in ['normal']:
add_method(d, tName, test_encode) add_method(d, tName, test_encode)
if variant in ['nondeterministic']: if variant in ['nondeterministic']:
add_method(d, tName, test_encode_canonical) add_method(d, tName, test_encode_canonical_annotated)
if variant in ['normal', 'nondeterministic']: if variant in ['normal', 'nondeterministic']:
add_method(d, tName, test_encode_ann) add_method(d, tName, test_encode_ann)
@ -322,8 +331,8 @@ class CommonTestSuite(unittest.TestCase):
def E(self, v): def E(self, v):
return encode(v, encode_embedded=lambda x: x) return encode(v, encode_embedded=lambda x: x)
def EC(self, v): def ECA(self, v):
return encode(v, encode_embedded=lambda x: x, canonicalize=True) return encode(v, encode_embedded=lambda x: x, canonicalize=True, include_annotations=True)
class RecordTests(unittest.TestCase): class RecordTests(unittest.TestCase):
def test_getters(self): def test_getters(self):