From 15d14cdafe272fd647e42a7c826ef37efc5c865e Mon Sep 17 00:00:00 2001 From: Tony Garnock-Jones Date: Sat, 31 Aug 2019 20:52:32 +0100 Subject: [PATCH] Common tests for Python impl --- implementations/python/preserves/__init__.py | 2 +- implementations/python/preserves/preserves.py | 23 ++- .../python/preserves/test_preserves.py | 143 ++++++++++++++++-- 3 files changed, 148 insertions(+), 20 deletions(-) diff --git a/implementations/python/preserves/__init__.py b/implementations/python/preserves/__init__.py index 68c5323..de523b1 100644 --- a/implementations/python/preserves/__init__.py +++ b/implementations/python/preserves/__init__.py @@ -2,7 +2,7 @@ from .preserves import Float, Symbol, Record, ImmutableDict from .preserves import DecodeError, EncodeError, ShortPacket -from .preserves import Decoder, Encoder +from .preserves import Decoder, Encoder, decode, decode_with_annotations, encode from .preserves import Annotated, is_annotated, strip_annotations, annotate diff --git a/implementations/python/preserves/preserves.py b/implementations/python/preserves/preserves.py index 2a9f253..bbab159 100644 --- a/implementations/python/preserves/preserves.py +++ b/implementations/python/preserves/preserves.py @@ -246,6 +246,9 @@ class Annotated(object): def strip(self, depth=inf): return strip_annotations(self, depth) + def peel(self): + return strip_annotations(self, 1) + def __eq__(self, other): if other.__class__ is self.__class__: return self.item == other.item @@ -256,6 +259,9 @@ class Annotated(object): def __hash__(self): return hash(self.item) + def __repr__(self): + return ' '.join(list('@' + repr(a) for a in self.annotations) + [repr(self.item)]) + def is_annotated(v): return isinstance(v, Annotated) @@ -265,11 +271,11 @@ def strip_annotations(v, depth=inf): next_depth = depth - 1 def walk(v): - strip_annotations(v, next_depth) + return strip_annotations(v, next_depth) v = v.item if isinstance(v, Record): - return Record(walk(v.key), tuple(walk(f) for f in v.fields)) + return Record(strip_annotations(v.key, depth), tuple(walk(f) for f in v.fields)) elif isinstance(v, list): return tuple(walk(f) for f in v) elif isinstance(v, tuple): @@ -391,7 +397,7 @@ class Decoder(Codec): return Annotated(v) if self.include_annotations else v def unshift_annotation(self, a, v): - if this.include_annotations: + if self.include_annotations: v.annotations.insert(0, a) return v @@ -438,6 +444,12 @@ class Decoder(Codec): self.index = start return None +def decode(bs, placeholders={}): + return Decoder(packet=bs, placeholders=placeholders).next() + +def decode_with_annotations(bs, placeholders={}): + return Decoder(packet=bs, placeholders=placeholders, include_annotations=True).next() + class Encoder(Codec): def __init__(self, placeholders={}): super(Encoder, self).__init__() @@ -528,3 +540,8 @@ class Encoder(Codec): except TypeError: raise EncodeError('Cannot encode %r' % (v,)) self.encodestream(2, 1, i) + +def encode(v, placeholders={}): + e = Encoder(placeholders=placeholders) + e.append(v) + return e.contents() diff --git a/implementations/python/preserves/test_preserves.py b/implementations/python/preserves/test_preserves.py index bed5e07..2d61439 100644 --- a/implementations/python/preserves/test_preserves.py +++ b/implementations/python/preserves/test_preserves.py @@ -1,5 +1,6 @@ from .preserves import * import unittest +import sys if isinstance(chr(123), bytes): def _byte(x): @@ -32,30 +33,18 @@ def _varint(v): return e.contents() def _d(bs): - d = Decoder(bs, placeholders={ + return decode(bs, placeholders={ 0: Symbol('discard'), 1: Symbol('capture'), 2: Symbol('observe'), }) - return d.next() - -_all_encoded = set() - -def tearDownModule(): - print() - for bs in sorted(_all_encoded): - print(_hex(bs)) def _e(v): - e = Encoder(placeholders={ + return encode(v, placeholders={ Symbol('discard'): 0, Symbol('capture'): 1, Symbol('observe'): 2, }) - e.append(v) - bs = e.contents() - _all_encoded.add(bs) - return bs def _R(k, *args): return Record(Symbol(k), args) @@ -198,8 +187,130 @@ class CodecTests(unittest.TestCase): _buf(0x29, 0x25, 0x63, 'abc', 0x04, 0x25, 0x63, 'def', 0x04, 0x04), back=(u'abc', u'def')) - def test_common_test_suite(self): - self.fail('Common test suite needs to be implemented') +def add_method(d, tName, fn): + if hasattr(fn, 'func_name'): + # python2 + fname = str(fn.func_name + '_' + tName) + fn.func_name = fname + else: + # python3 + fname = str(fn.__name__ + '_' + tName) + fn.__name__ = fname + d[fname] = fn + +expected_values = { + "annotation1": { "forward": annotate(9, u"abc"), "back": 9 }, + "annotation2": { "forward": annotate([[], annotate([], u"x")], u"abc", u"def"), "back": ((), ()) }, + "annotation3": { "forward": annotate(5, annotate(2, 1), annotate(4, 3)), "back": 5 }, + "annotation5": { "forward": annotate(_R('R', annotate(Symbol('f'), Symbol('af'))), + Symbol('ar')), + "back": _R('R', Symbol('f')) }, + "annotation6": { "forward": Record(annotate(Symbol('R'), Symbol('ar')), + [annotate(Symbol('f'), Symbol('af'))]), + "back": _R('R', Symbol('f')) }, + "annotation7": { "forward": annotate([], Symbol('a'), Symbol('b'), Symbol('c')), + "back": () }, + "bytes1": { "forward": BinaryStream([b'he', b'll', b'o']), "back": b'hello' }, + "list1": { "forward": SequenceStream([1, 2, 3, 4]), "back": (1, 2, 3, 4) }, + "list2": { "forward": SequenceStream([ StringStream([b'abc']), StringStream([b'def']) ]), + "back": (u"abc", u"def") }, + "list3": { "forward": SequenceStream([[u"a", 1], [u"b", 2], [u"c", 3]]), + "back": ((u"a", 1), (u"b", 2), (u"c", 3)) }, + "record2": { "value": _R('observe', _R('speak', _R('discard'), _R('capture', _R('discard')))) }, + "string0a": { "forward": StringStream([]), "back": u'' }, + "string1": { "forward": StringStream([b'he', b'll', b'o']), "back": u'hello' }, + "string2": { "forward": StringStream([b'he', b'llo']), "back": u'hello' }, + "symbol1": { "forward": SymbolStream([b'he', b'll', b'o']), "back": Symbol('hello') }, +} + +def get_expected_values(tName, textForm): + entry = expected_values.get(tName, {"value": textForm}) + if 'value' in entry: + return { "forward": entry['value'], "back": entry['value'] } + elif 'forward' in entry and 'back' in entry: + return entry + else: + raise Exception('Invalid expected_values entry for ' + tName) + +def install_test(d, variant, tName, binaryForm, annotatedTextForm): + textForm = annotatedTextForm.strip() + entry = get_expected_values(tName, textForm) + forward = entry['forward'] + back = entry['back'] + def test_match_expected(self): self.assertEqual(textForm, back) + def test_roundtrip(self): self.assertEqual(self.DS(self.E(textForm)), back) + def test_forward(self): self.assertEqual(self.DS(self.E(forward)), back) + def test_back(self): self.assertEqual(self.DS(binaryForm), back) + def test_back_ann(self): self.assertEqual(self.D(self.E(annotatedTextForm)), annotatedTextForm) + def test_encode(self): self.assertEqual(self.E(forward), binaryForm) + def test_encode_ann(self): self.assertEqual(self.E(annotatedTextForm), binaryForm) + add_method(d, tName, test_match_expected) + add_method(d, tName, test_roundtrip) + add_method(d, tName, test_forward) + add_method(d, tName, test_back) + add_method(d, tName, test_back_ann) + if variant not in ['nondeterministic']: + add_method(d, tName, test_encode) + if variant not in ['nondeterministic', 'streaming']: + add_method(d, tName, test_encode_ann) + +def install_exn_test(d, tName, bs, check_proc): + def test_exn(self): + try: + self.D(bs) + except: + check_proc(self, sys.exc_info()[1]) + return + self.fail('did not fail as expected') + add_method(d, tName, test_exn) + +class CommonTestSuite(unittest.TestCase): + import os + with open(os.path.join(os.path.dirname(__file__), + '../../../tests/samples.bin'), 'rb') as f: + samples = Decoder(f.read(), include_annotations=True).next() + + TestCases = Record.makeConstructor('TestCases', 'mapping cases') + ExpectedPlaceholderMapping = Record.makeConstructor('ExpectedPlaceholderMapping', 'table') + + m = TestCases._mapping(samples.peel()).strip() + placeholders_decode = ExpectedPlaceholderMapping._table(m) + placeholders_encode = dict((v,k) for (k,v) in placeholders_decode.items()) + + tests = TestCases._cases(samples.peel()).peel() + for (tName0, t0) in tests.items(): + tName = tName0.strip().name + t = t0.peel() + if t.key == Symbol('Test'): + install_test(locals(), 'normal', tName, t[0].strip(), t[1]) + elif t.key == Symbol('StreamingTest'): + install_test(locals(), 'streaming', tName, t[0].strip(), t[1]) + elif t.key == Symbol('NondeterministicTest'): + install_test(locals(), 'nondeterministic', tName, t[0].strip(), t[1]) + elif t.key == Symbol('DecodeError'): + def expected_err(self, e): + self.assertIsInstance(e, DecodeError) + self.assertNotIsInstance(e, ShortPacket) + install_exn_test(locals(), tName, t[0].strip(), expected_err) + elif t.key == Symbol('DecodeShort'): + def expected_short(self, e): + self.assertIsInstance(e, ShortPacket) + install_exn_test(locals(), tName, t[0].strip(), expected_short) + elif t.key == Symbol('ParseError') or \ + t.key == Symbol('ParseShort'): + # Skipped for now, until we have an implementation of text syntax + pass + else: + raise Exception('Unsupported test kind', t.key) + + def DS(self, bs): + return decode(bs, placeholders=self.placeholders_decode) + + def D(self, bs): + return decode_with_annotations(bs, placeholders=self.placeholders_decode) + + def E(self, v): + return encode(v, placeholders=self.placeholders_encode) class RecordTests(unittest.TestCase): def test_getters(self):