Common tests for Python impl

This commit is contained in:
Tony Garnock-Jones 2019-08-31 20:52:32 +01:00
parent b2203b6f5b
commit 15d14cdafe
3 changed files with 148 additions and 20 deletions

View File

@ -2,7 +2,7 @@ from .preserves import Float, Symbol, Record, ImmutableDict
from .preserves import DecodeError, EncodeError, ShortPacket from .preserves import DecodeError, EncodeError, ShortPacket
from .preserves import Decoder, Encoder from .preserves import Decoder, Encoder, decode, decode_with_annotations, encode
from .preserves import Annotated, is_annotated, strip_annotations, annotate from .preserves import Annotated, is_annotated, strip_annotations, annotate

View File

@ -246,6 +246,9 @@ class Annotated(object):
def strip(self, depth=inf): def strip(self, depth=inf):
return strip_annotations(self, depth) return strip_annotations(self, depth)
def peel(self):
return strip_annotations(self, 1)
def __eq__(self, other): def __eq__(self, other):
if other.__class__ is self.__class__: if other.__class__ is self.__class__:
return self.item == other.item return self.item == other.item
@ -256,6 +259,9 @@ class Annotated(object):
def __hash__(self): def __hash__(self):
return hash(self.item) return hash(self.item)
def __repr__(self):
return ' '.join(list('@' + repr(a) for a in self.annotations) + [repr(self.item)])
def is_annotated(v): def is_annotated(v):
return isinstance(v, Annotated) return isinstance(v, Annotated)
@ -265,11 +271,11 @@ def strip_annotations(v, depth=inf):
next_depth = depth - 1 next_depth = depth - 1
def walk(v): def walk(v):
strip_annotations(v, next_depth) return strip_annotations(v, next_depth)
v = v.item v = v.item
if isinstance(v, Record): if isinstance(v, Record):
return Record(walk(v.key), tuple(walk(f) for f in v.fields)) return Record(strip_annotations(v.key, depth), tuple(walk(f) for f in v.fields))
elif isinstance(v, list): elif isinstance(v, list):
return tuple(walk(f) for f in v) return tuple(walk(f) for f in v)
elif isinstance(v, tuple): elif isinstance(v, tuple):
@ -391,7 +397,7 @@ class Decoder(Codec):
return Annotated(v) if self.include_annotations else v return Annotated(v) if self.include_annotations else v
def unshift_annotation(self, a, v): def unshift_annotation(self, a, v):
if this.include_annotations: if self.include_annotations:
v.annotations.insert(0, a) v.annotations.insert(0, a)
return v return v
@ -438,6 +444,12 @@ class Decoder(Codec):
self.index = start self.index = start
return None return None
def decode(bs, placeholders={}):
return Decoder(packet=bs, placeholders=placeholders).next()
def decode_with_annotations(bs, placeholders={}):
return Decoder(packet=bs, placeholders=placeholders, include_annotations=True).next()
class Encoder(Codec): class Encoder(Codec):
def __init__(self, placeholders={}): def __init__(self, placeholders={}):
super(Encoder, self).__init__() super(Encoder, self).__init__()
@ -528,3 +540,8 @@ class Encoder(Codec):
except TypeError: except TypeError:
raise EncodeError('Cannot encode %r' % (v,)) raise EncodeError('Cannot encode %r' % (v,))
self.encodestream(2, 1, i) self.encodestream(2, 1, i)
def encode(v, placeholders={}):
e = Encoder(placeholders=placeholders)
e.append(v)
return e.contents()

View File

@ -1,5 +1,6 @@
from .preserves import * from .preserves import *
import unittest import unittest
import sys
if isinstance(chr(123), bytes): if isinstance(chr(123), bytes):
def _byte(x): def _byte(x):
@ -32,30 +33,18 @@ def _varint(v):
return e.contents() return e.contents()
def _d(bs): def _d(bs):
d = Decoder(bs, placeholders={ return decode(bs, placeholders={
0: Symbol('discard'), 0: Symbol('discard'),
1: Symbol('capture'), 1: Symbol('capture'),
2: Symbol('observe'), 2: Symbol('observe'),
}) })
return d.next()
_all_encoded = set()
def tearDownModule():
print()
for bs in sorted(_all_encoded):
print(_hex(bs))
def _e(v): def _e(v):
e = Encoder(placeholders={ return encode(v, placeholders={
Symbol('discard'): 0, Symbol('discard'): 0,
Symbol('capture'): 1, Symbol('capture'): 1,
Symbol('observe'): 2, Symbol('observe'): 2,
}) })
e.append(v)
bs = e.contents()
_all_encoded.add(bs)
return bs
def _R(k, *args): def _R(k, *args):
return Record(Symbol(k), args) return Record(Symbol(k), args)
@ -198,8 +187,130 @@ class CodecTests(unittest.TestCase):
_buf(0x29, 0x25, 0x63, 'abc', 0x04, 0x25, 0x63, 'def', 0x04, 0x04), _buf(0x29, 0x25, 0x63, 'abc', 0x04, 0x25, 0x63, 'def', 0x04, 0x04),
back=(u'abc', u'def')) back=(u'abc', u'def'))
def test_common_test_suite(self): def add_method(d, tName, fn):
self.fail('Common test suite needs to be implemented') if hasattr(fn, 'func_name'):
# python2
fname = str(fn.func_name + '_' + tName)
fn.func_name = fname
else:
# python3
fname = str(fn.__name__ + '_' + tName)
fn.__name__ = fname
d[fname] = fn
expected_values = {
"annotation1": { "forward": annotate(9, u"abc"), "back": 9 },
"annotation2": { "forward": annotate([[], annotate([], u"x")], u"abc", u"def"), "back": ((), ()) },
"annotation3": { "forward": annotate(5, annotate(2, 1), annotate(4, 3)), "back": 5 },
"annotation5": { "forward": annotate(_R('R', annotate(Symbol('f'), Symbol('af'))),
Symbol('ar')),
"back": _R('R', Symbol('f')) },
"annotation6": { "forward": Record(annotate(Symbol('R'), Symbol('ar')),
[annotate(Symbol('f'), Symbol('af'))]),
"back": _R('R', Symbol('f')) },
"annotation7": { "forward": annotate([], Symbol('a'), Symbol('b'), Symbol('c')),
"back": () },
"bytes1": { "forward": BinaryStream([b'he', b'll', b'o']), "back": b'hello' },
"list1": { "forward": SequenceStream([1, 2, 3, 4]), "back": (1, 2, 3, 4) },
"list2": { "forward": SequenceStream([ StringStream([b'abc']), StringStream([b'def']) ]),
"back": (u"abc", u"def") },
"list3": { "forward": SequenceStream([[u"a", 1], [u"b", 2], [u"c", 3]]),
"back": ((u"a", 1), (u"b", 2), (u"c", 3)) },
"record2": { "value": _R('observe', _R('speak', _R('discard'), _R('capture', _R('discard')))) },
"string0a": { "forward": StringStream([]), "back": u'' },
"string1": { "forward": StringStream([b'he', b'll', b'o']), "back": u'hello' },
"string2": { "forward": StringStream([b'he', b'llo']), "back": u'hello' },
"symbol1": { "forward": SymbolStream([b'he', b'll', b'o']), "back": Symbol('hello') },
}
def get_expected_values(tName, textForm):
entry = expected_values.get(tName, {"value": textForm})
if 'value' in entry:
return { "forward": entry['value'], "back": entry['value'] }
elif 'forward' in entry and 'back' in entry:
return entry
else:
raise Exception('Invalid expected_values entry for ' + tName)
def install_test(d, variant, tName, binaryForm, annotatedTextForm):
textForm = annotatedTextForm.strip()
entry = get_expected_values(tName, textForm)
forward = entry['forward']
back = entry['back']
def test_match_expected(self): self.assertEqual(textForm, back)
def test_roundtrip(self): self.assertEqual(self.DS(self.E(textForm)), back)
def test_forward(self): self.assertEqual(self.DS(self.E(forward)), back)
def test_back(self): self.assertEqual(self.DS(binaryForm), back)
def test_back_ann(self): self.assertEqual(self.D(self.E(annotatedTextForm)), annotatedTextForm)
def test_encode(self): self.assertEqual(self.E(forward), binaryForm)
def test_encode_ann(self): self.assertEqual(self.E(annotatedTextForm), binaryForm)
add_method(d, tName, test_match_expected)
add_method(d, tName, test_roundtrip)
add_method(d, tName, test_forward)
add_method(d, tName, test_back)
add_method(d, tName, test_back_ann)
if variant not in ['nondeterministic']:
add_method(d, tName, test_encode)
if variant not in ['nondeterministic', 'streaming']:
add_method(d, tName, test_encode_ann)
def install_exn_test(d, tName, bs, check_proc):
def test_exn(self):
try:
self.D(bs)
except:
check_proc(self, sys.exc_info()[1])
return
self.fail('did not fail as expected')
add_method(d, tName, test_exn)
class CommonTestSuite(unittest.TestCase):
import os
with open(os.path.join(os.path.dirname(__file__),
'../../../tests/samples.bin'), 'rb') as f:
samples = Decoder(f.read(), include_annotations=True).next()
TestCases = Record.makeConstructor('TestCases', 'mapping cases')
ExpectedPlaceholderMapping = Record.makeConstructor('ExpectedPlaceholderMapping', 'table')
m = TestCases._mapping(samples.peel()).strip()
placeholders_decode = ExpectedPlaceholderMapping._table(m)
placeholders_encode = dict((v,k) for (k,v) in placeholders_decode.items())
tests = TestCases._cases(samples.peel()).peel()
for (tName0, t0) in tests.items():
tName = tName0.strip().name
t = t0.peel()
if t.key == Symbol('Test'):
install_test(locals(), 'normal', tName, t[0].strip(), t[1])
elif t.key == Symbol('StreamingTest'):
install_test(locals(), 'streaming', tName, t[0].strip(), t[1])
elif t.key == Symbol('NondeterministicTest'):
install_test(locals(), 'nondeterministic', tName, t[0].strip(), t[1])
elif t.key == Symbol('DecodeError'):
def expected_err(self, e):
self.assertIsInstance(e, DecodeError)
self.assertNotIsInstance(e, ShortPacket)
install_exn_test(locals(), tName, t[0].strip(), expected_err)
elif t.key == Symbol('DecodeShort'):
def expected_short(self, e):
self.assertIsInstance(e, ShortPacket)
install_exn_test(locals(), tName, t[0].strip(), expected_short)
elif t.key == Symbol('ParseError') or \
t.key == Symbol('ParseShort'):
# Skipped for now, until we have an implementation of text syntax
pass
else:
raise Exception('Unsupported test kind', t.key)
def DS(self, bs):
return decode(bs, placeholders=self.placeholders_decode)
def D(self, bs):
return decode_with_annotations(bs, placeholders=self.placeholders_decode)
def E(self, v):
return encode(v, placeholders=self.placeholders_encode)
class RecordTests(unittest.TestCase): class RecordTests(unittest.TestCase):
def test_getters(self): def test_getters(self):