Common tests for Python impl

This commit is contained in:
Tony Garnock-Jones 2019-08-31 20:52:32 +01:00
parent b2203b6f5b
commit 15d14cdafe
3 changed files with 148 additions and 20 deletions

View File

@ -2,7 +2,7 @@ from .preserves import Float, Symbol, Record, ImmutableDict
from .preserves import DecodeError, EncodeError, ShortPacket
from .preserves import Decoder, Encoder
from .preserves import Decoder, Encoder, decode, decode_with_annotations, encode
from .preserves import Annotated, is_annotated, strip_annotations, annotate

View File

@ -246,6 +246,9 @@ class Annotated(object):
def strip(self, depth=inf):
return strip_annotations(self, depth)
def peel(self):
return strip_annotations(self, 1)
def __eq__(self, other):
if other.__class__ is self.__class__:
return self.item == other.item
@ -256,6 +259,9 @@ class Annotated(object):
def __hash__(self):
return hash(self.item)
def __repr__(self):
return ' '.join(list('@' + repr(a) for a in self.annotations) + [repr(self.item)])
def is_annotated(v):
return isinstance(v, Annotated)
@ -265,11 +271,11 @@ def strip_annotations(v, depth=inf):
next_depth = depth - 1
def walk(v):
strip_annotations(v, next_depth)
return strip_annotations(v, next_depth)
v = v.item
if isinstance(v, Record):
return Record(walk(v.key), tuple(walk(f) for f in v.fields))
return Record(strip_annotations(v.key, depth), tuple(walk(f) for f in v.fields))
elif isinstance(v, list):
return tuple(walk(f) for f in v)
elif isinstance(v, tuple):
@ -391,7 +397,7 @@ class Decoder(Codec):
return Annotated(v) if self.include_annotations else v
def unshift_annotation(self, a, v):
if this.include_annotations:
if self.include_annotations:
v.annotations.insert(0, a)
return v
@ -438,6 +444,12 @@ class Decoder(Codec):
self.index = start
return None
def decode(bs, placeholders={}):
return Decoder(packet=bs, placeholders=placeholders).next()
def decode_with_annotations(bs, placeholders={}):
return Decoder(packet=bs, placeholders=placeholders, include_annotations=True).next()
class Encoder(Codec):
def __init__(self, placeholders={}):
super(Encoder, self).__init__()
@ -528,3 +540,8 @@ class Encoder(Codec):
except TypeError:
raise EncodeError('Cannot encode %r' % (v,))
self.encodestream(2, 1, i)
def encode(v, placeholders={}):
e = Encoder(placeholders=placeholders)
e.append(v)
return e.contents()

View File

@ -1,5 +1,6 @@
from .preserves import *
import unittest
import sys
if isinstance(chr(123), bytes):
def _byte(x):
@ -32,30 +33,18 @@ def _varint(v):
return e.contents()
def _d(bs):
d = Decoder(bs, placeholders={
return decode(bs, placeholders={
0: Symbol('discard'),
1: Symbol('capture'),
2: Symbol('observe'),
})
return d.next()
_all_encoded = set()
def tearDownModule():
print()
for bs in sorted(_all_encoded):
print(_hex(bs))
def _e(v):
e = Encoder(placeholders={
return encode(v, placeholders={
Symbol('discard'): 0,
Symbol('capture'): 1,
Symbol('observe'): 2,
})
e.append(v)
bs = e.contents()
_all_encoded.add(bs)
return bs
def _R(k, *args):
return Record(Symbol(k), args)
@ -198,8 +187,130 @@ class CodecTests(unittest.TestCase):
_buf(0x29, 0x25, 0x63, 'abc', 0x04, 0x25, 0x63, 'def', 0x04, 0x04),
back=(u'abc', u'def'))
def test_common_test_suite(self):
self.fail('Common test suite needs to be implemented')
def add_method(d, tName, fn):
if hasattr(fn, 'func_name'):
# python2
fname = str(fn.func_name + '_' + tName)
fn.func_name = fname
else:
# python3
fname = str(fn.__name__ + '_' + tName)
fn.__name__ = fname
d[fname] = fn
expected_values = {
"annotation1": { "forward": annotate(9, u"abc"), "back": 9 },
"annotation2": { "forward": annotate([[], annotate([], u"x")], u"abc", u"def"), "back": ((), ()) },
"annotation3": { "forward": annotate(5, annotate(2, 1), annotate(4, 3)), "back": 5 },
"annotation5": { "forward": annotate(_R('R', annotate(Symbol('f'), Symbol('af'))),
Symbol('ar')),
"back": _R('R', Symbol('f')) },
"annotation6": { "forward": Record(annotate(Symbol('R'), Symbol('ar')),
[annotate(Symbol('f'), Symbol('af'))]),
"back": _R('R', Symbol('f')) },
"annotation7": { "forward": annotate([], Symbol('a'), Symbol('b'), Symbol('c')),
"back": () },
"bytes1": { "forward": BinaryStream([b'he', b'll', b'o']), "back": b'hello' },
"list1": { "forward": SequenceStream([1, 2, 3, 4]), "back": (1, 2, 3, 4) },
"list2": { "forward": SequenceStream([ StringStream([b'abc']), StringStream([b'def']) ]),
"back": (u"abc", u"def") },
"list3": { "forward": SequenceStream([[u"a", 1], [u"b", 2], [u"c", 3]]),
"back": ((u"a", 1), (u"b", 2), (u"c", 3)) },
"record2": { "value": _R('observe', _R('speak', _R('discard'), _R('capture', _R('discard')))) },
"string0a": { "forward": StringStream([]), "back": u'' },
"string1": { "forward": StringStream([b'he', b'll', b'o']), "back": u'hello' },
"string2": { "forward": StringStream([b'he', b'llo']), "back": u'hello' },
"symbol1": { "forward": SymbolStream([b'he', b'll', b'o']), "back": Symbol('hello') },
}
def get_expected_values(tName, textForm):
entry = expected_values.get(tName, {"value": textForm})
if 'value' in entry:
return { "forward": entry['value'], "back": entry['value'] }
elif 'forward' in entry and 'back' in entry:
return entry
else:
raise Exception('Invalid expected_values entry for ' + tName)
def install_test(d, variant, tName, binaryForm, annotatedTextForm):
textForm = annotatedTextForm.strip()
entry = get_expected_values(tName, textForm)
forward = entry['forward']
back = entry['back']
def test_match_expected(self): self.assertEqual(textForm, back)
def test_roundtrip(self): self.assertEqual(self.DS(self.E(textForm)), back)
def test_forward(self): self.assertEqual(self.DS(self.E(forward)), back)
def test_back(self): self.assertEqual(self.DS(binaryForm), back)
def test_back_ann(self): self.assertEqual(self.D(self.E(annotatedTextForm)), annotatedTextForm)
def test_encode(self): self.assertEqual(self.E(forward), binaryForm)
def test_encode_ann(self): self.assertEqual(self.E(annotatedTextForm), binaryForm)
add_method(d, tName, test_match_expected)
add_method(d, tName, test_roundtrip)
add_method(d, tName, test_forward)
add_method(d, tName, test_back)
add_method(d, tName, test_back_ann)
if variant not in ['nondeterministic']:
add_method(d, tName, test_encode)
if variant not in ['nondeterministic', 'streaming']:
add_method(d, tName, test_encode_ann)
def install_exn_test(d, tName, bs, check_proc):
def test_exn(self):
try:
self.D(bs)
except:
check_proc(self, sys.exc_info()[1])
return
self.fail('did not fail as expected')
add_method(d, tName, test_exn)
class CommonTestSuite(unittest.TestCase):
import os
with open(os.path.join(os.path.dirname(__file__),
'../../../tests/samples.bin'), 'rb') as f:
samples = Decoder(f.read(), include_annotations=True).next()
TestCases = Record.makeConstructor('TestCases', 'mapping cases')
ExpectedPlaceholderMapping = Record.makeConstructor('ExpectedPlaceholderMapping', 'table')
m = TestCases._mapping(samples.peel()).strip()
placeholders_decode = ExpectedPlaceholderMapping._table(m)
placeholders_encode = dict((v,k) for (k,v) in placeholders_decode.items())
tests = TestCases._cases(samples.peel()).peel()
for (tName0, t0) in tests.items():
tName = tName0.strip().name
t = t0.peel()
if t.key == Symbol('Test'):
install_test(locals(), 'normal', tName, t[0].strip(), t[1])
elif t.key == Symbol('StreamingTest'):
install_test(locals(), 'streaming', tName, t[0].strip(), t[1])
elif t.key == Symbol('NondeterministicTest'):
install_test(locals(), 'nondeterministic', tName, t[0].strip(), t[1])
elif t.key == Symbol('DecodeError'):
def expected_err(self, e):
self.assertIsInstance(e, DecodeError)
self.assertNotIsInstance(e, ShortPacket)
install_exn_test(locals(), tName, t[0].strip(), expected_err)
elif t.key == Symbol('DecodeShort'):
def expected_short(self, e):
self.assertIsInstance(e, ShortPacket)
install_exn_test(locals(), tName, t[0].strip(), expected_short)
elif t.key == Symbol('ParseError') or \
t.key == Symbol('ParseShort'):
# Skipped for now, until we have an implementation of text syntax
pass
else:
raise Exception('Unsupported test kind', t.key)
def DS(self, bs):
return decode(bs, placeholders=self.placeholders_decode)
def D(self, bs):
return decode_with_annotations(bs, placeholders=self.placeholders_decode)
def E(self, v):
return encode(v, placeholders=self.placeholders_encode)
class RecordTests(unittest.TestCase):
def test_getters(self):