From 67613877cee2f67f61dab4d1ec53565bf8f0819c Mon Sep 17 00:00:00 2001 From: Tony Garnock-Jones Date: Tue, 8 Nov 2022 16:13:30 +0100 Subject: [PATCH] Update Python implementation; repair comparison routines --- implementations/python/preserves/binary.py | 2 +- implementations/python/preserves/compare.py | 18 +++- implementations/python/preserves/text.py | 92 +++++++------------ implementations/python/preserves/values.py | 50 +++++++++- implementations/python/tests/test_compare.py | 4 +- implementations/python/tests/test_path.py | 34 +++---- .../python/tests/test_preserves.py | 85 ++++++++--------- implementations/python/tests/test_schema.py | 10 +- implementations/python/tests/utils.py | 9 ++ 9 files changed, 172 insertions(+), 132 deletions(-) create mode 100644 implementations/python/tests/utils.py diff --git a/implementations/python/preserves/binary.py b/implementations/python/preserves/binary.py index 78ec1dd..8b246f4 100644 --- a/implementations/python/preserves/binary.py +++ b/implementations/python/preserves/binary.py @@ -72,7 +72,7 @@ class Decoder(BinaryCodec): tag = self.nextbyte() if tag == 0x80: return self.wrap(False) if tag == 0x81: return self.wrap(True) - if tag == 0x82: return self.wrap(Float(struct.unpack('>f', self.nextbytes(4))[0])) + if tag == 0x82: return self.wrap(Float.from_bytes(self.nextbytes(4))) if tag == 0x83: return self.wrap(struct.unpack('>d', self.nextbytes(8))[0]) if tag == 0x84: raise DecodeError('Unexpected end-of-stream marker') if tag == 0x85: diff --git a/implementations/python/preserves/compare.py b/implementations/python/preserves/compare.py index 87df7ff..1572426 100644 --- a/implementations/python/preserves/compare.py +++ b/implementations/python/preserves/compare.py @@ -2,7 +2,7 @@ import numbers from enum import Enum from functools import cmp_to_key -from .values import preserve, Float, Embedded, Record, Symbol +from .values import preserve, Float, Embedded, Record, Symbol, cmp_floats, _unwrap from .compat import basestring_ class TypeNumber(Enum): @@ -19,7 +19,7 @@ class TypeNumber(Enum): SET = 9 DICTIONARY = 10 - EMBEDDED = 10 + EMBEDDED = 11 def type_number(v): if hasattr(v, '__preserve__'): @@ -84,12 +84,17 @@ def _item_key(item): return item[0] def _eq(a, b): + a = _unwrap(a) + b = _unwrap(b) ta = type_number(a) tb = type_number(b) if ta != tb: return False + if ta == TypeNumber.DOUBLE: + return cmp_floats(a, b) == 0 + if ta == TypeNumber.EMBEDDED: - return ta.embeddedValue == tb.embeddedValue + return _eq(a.embeddedValue, b.embeddedValue) if ta == TypeNumber.RECORD: return _eq(a.key, b.key) and _eq_sequences(a.fields, b.fields) @@ -118,13 +123,18 @@ def _cmp_sequences(aa, bb): return len(aa) - len(bb) def _cmp(a, b): + a = _unwrap(a) + b = _unwrap(b) ta = type_number(a) tb = type_number(b) if ta.value < tb.value: return -1 if tb.value < ta.value: return 1 + if ta == TypeNumber.DOUBLE: + return cmp_floats(a, b) + if ta == TypeNumber.EMBEDDED: - return _simplecmp(ta.embeddedValue, tb.embeddedValue) + return _cmp(a.embeddedValue, b.embeddedValue) if ta == TypeNumber.RECORD: v = _cmp(a.key, b.key) diff --git a/implementations/python/preserves/text.py b/implementations/python/preserves/text.py index 424441b..321f8b3 100644 --- a/implementations/python/preserves/text.py +++ b/implementations/python/preserves/text.py @@ -1,6 +1,7 @@ import numbers import struct import base64 +import math from .values import * from .error import * @@ -9,6 +10,8 @@ from .binary import Decoder class TextCodec(object): pass +NUMBER_RE = re.compile(r'^([-+]?\d+)(((\.\d+([eE][-+]?\d+)?)|([eE][-+]?\d+))([fF]?))?$') + class Parser(TextCodec): def __init__(self, input_buffer=u'', include_annotations=False, parse_embedded=lambda x: x): super(Parser, self).__init__() @@ -66,50 +69,6 @@ class Parser(TextCodec): return self.wrap(u''.join(s)) s.append(c) - def read_intpart(self, acc, c): - if c == '0': - acc.append(c) - else: - self.read_digit1(acc, c) - return self.read_fracexp(acc) - - def read_fracexp(self, acc): - is_float = False - if self.peek() == '.': - is_float = True - acc.append(self.nextchar()) - self.read_digit1(acc, self.nextchar()) - if self.peek() in 'eE': - acc.append(self.nextchar()) - return self.read_sign_and_exp(acc) - else: - return self.finish_number(acc, is_float) - - def read_sign_and_exp(self, acc): - if self.peek() in '+-': - acc.append(self.nextchar()) - self.read_digit1(acc, self.nextchar()) - return self.finish_number(acc, True) - - def finish_number(self, acc, is_float): - if is_float: - if self.peek() in 'fF': - self.skip() - return Float(float(u''.join(acc))) - else: - return float(u''.join(acc)) - else: - return int(u''.join(acc)) - - def read_digit1(self, acc, c): - if not c.isdigit(): - raise DecodeError('Incomplete number') - acc.append(c) - while not self._atend(): - if not self.peek().isdigit(): - break - acc.append(self.nextchar()) - def read_stringlike(self, terminator, hexescape, hexescaper): acc = [] while True: @@ -186,6 +145,16 @@ class Parser(TextCodec): if c == '=': continue acc.append(c) + def read_hex_float(self, bytecount): + if self.nextchar() != '"': + raise DecodeError('Missing open-double-quote in hex-encoded floating-point number') + bs = self.read_hex_binary() + if len(bs) != bytecount: + raise DecodeError('Incorrect number of bytes in hex-encoded floating-point number') + if bytecount == 4: return Float.from_bytes(bs) + if bytecount == 8: return struct.unpack('>d', bs)[0] + raise DecodeError('Unsupported byte count in hex-encoded floating-point number') + def upto(self, delimiter): vs = [] while True: @@ -208,14 +177,24 @@ class Parser(TextCodec): raise DecodeError('Missing expected key/value separator') acc.append(self.next()) - def read_raw_symbol(self, acc): + def read_raw_symbol_or_number(self, acc): while not self._atend(): c = self.peek() if c.isspace() or c in '(){}[]<>";,@#:|': break self.skip() acc.append(c) - return Symbol(u''.join(acc)) + acc = u''.join(acc) + m = NUMBER_RE.match(acc) + if m: + if m[2] is None: + return int(m[1]) + elif m[7] == '': + return float(m[1] + m[3]) + else: + return Float(float(m[1] + m[3])) + else: + return Symbol(acc) def wrap(self, v): return Annotated(v) if self.include_annotations else v @@ -223,12 +202,6 @@ class Parser(TextCodec): def next(self): self.skip_whitespace() c = self.peek() - if c == '-': - self.skip() - return self.wrap(self.read_intpart(['-'], self.nextchar())) - if c.isdigit(): - self.skip() - return self.wrap(self.read_intpart([], c)) if c == '"': self.skip() return self.wrap(self.read_string('"')) @@ -251,9 +224,11 @@ class Parser(TextCodec): if c == '{': return self.wrap(frozenset(self.upto('}'))) if c == '"': return self.wrap(self.read_literal_binary()) if c == 'x': - if self.nextchar() != '"': - raise DecodeError('Expected open-quote at start of hex ByteString') - return self.wrap(self.read_hex_binary()) + c = self.nextchar() + if c == '"': return self.wrap(self.read_hex_binary()) + if c == 'f': return self.wrap(self.read_hex_float(4)) + if c == 'd': return self.wrap(self.read_hex_float(8)) + raise DecodeError('Invalid #x syntax') if c == '[': return self.wrap(self.read_base64_binary()) if c == '=': old_ann = self.include_annotations @@ -286,7 +261,7 @@ class Parser(TextCodec): if c in '>]}': raise DecodeError('Unexpected ' + c) self.skip() - return self.wrap(self.read_raw_symbol([c])) + return self.wrap(self.read_raw_symbol_or_number([c])) def try_next(self): start = self.index @@ -385,7 +360,10 @@ class Formatter(TextCodec): elif v is True: self.chunks.append('#t') elif isinstance(v, float): - self.chunks.append(repr(v)) + if math.isnan(v) or math.isinf(v): + self.chunks.append('#xd"' + struct.pack('>d', v).hex() + '"') + else: + self.chunks.append(repr(v)) elif isinstance(v, numbers.Number): self.chunks.append('%d' % (v,)) elif isinstance(v, bytes): diff --git a/implementations/python/preserves/values.py b/implementations/python/preserves/values.py index b96bba3..cdb2fc0 100644 --- a/implementations/python/preserves/values.py +++ b/implementations/python/preserves/values.py @@ -1,6 +1,7 @@ import re import sys import struct +import math from .error import DecodeError @@ -9,6 +10,16 @@ def preserve(v): v = v.__preserve__() return v +def float_to_int(v): + return struct.unpack('>Q', struct.pack('>d', v))[0] + +def cmp_floats(a, b): + a = float_to_int(a) + b = float_to_int(b) + if a & 0x8000000000000000: a = a ^ 0x7fffffffffffffff + if b & 0x8000000000000000: b = b ^ 0x7fffffffffffffff + return a - b + class Float(object): def __init__(self, value): self.value = value @@ -16,7 +27,12 @@ class Float(object): def __eq__(self, other): other = _unwrap(other) if other.__class__ is self.__class__: - return self.value == other.value + return cmp_floats(self.value, other.value) == 0 + + def __lt__(self, other): + other = _unwrap(other) + if other.__class__ is self.__class__: + return cmp_floats(self.value, other.value) < 0 def __ne__(self, other): return not self.__eq__(other) @@ -27,15 +43,41 @@ class Float(object): def __repr__(self): return 'Float(' + repr(self.value) + ')' + def _to_bytes(self): + if math.isnan(self.value) or math.isinf(self.value): + dbs = struct.pack('>d', self.value) + vd = struct.unpack('>Q', dbs)[0] + sign = vd >> 63 + payload = (vd >> 29) & 0x007fffff + vf = (sign << 31) | 0x7f800000 | payload + return struct.pack('>I', vf) + else: + return struct.pack('>f', self.value) + def __preserve_write_binary__(self, encoder): encoder.buffer.append(0x82) - encoder.buffer.extend(struct.pack('>f', self.value)) + encoder.buffer.extend(self._to_bytes()) def __preserve_write_text__(self, formatter): - formatter.chunks.append(repr(self.value) + 'f') + if math.isnan(self.value) or math.isinf(self.value): + formatter.chunks.append('#xf"' + self._to_bytes().hex() + '"') + else: + formatter.chunks.append(repr(self.value) + 'f') + + @staticmethod + def from_bytes(bs): + vf = struct.unpack('>I', bs)[0] + if (vf & 0x7f800000) == 0x7f800000: + # NaN or inf. Preserve quiet/signalling bit by manually expanding to double-precision. + sign = vf >> 31 + payload = vf & 0x007fffff + dbs = struct.pack('>Q', (sign << 63) | 0x7ff0000000000000 | (payload << 29)) + return Float(struct.unpack('>d', dbs)[0]) + else: + return Float(struct.unpack('>f', bs)[0]) # FIXME: This regular expression is conservatively correct, but Anglo-chauvinistic. -RAW_SYMBOL_RE = re.compile(r'^[a-zA-Z~!$%^&*?_=+/.][-a-zA-Z~!$%^&*?_=+/.0-9]*$') +RAW_SYMBOL_RE = re.compile(r'^[-a-zA-Z0-9~!$%^&*?_=+/.]+$') class Symbol(object): def __init__(self, name): diff --git a/implementations/python/tests/test_compare.py b/implementations/python/tests/test_compare.py index 9bcf04e..569459d 100644 --- a/implementations/python/tests/test_compare.py +++ b/implementations/python/tests/test_compare.py @@ -1,9 +1,9 @@ -import unittest +from utils import PreservesTestCase from preserves import * from preserves.compare import * -class BasicCompareTests(unittest.TestCase): +class BasicCompareTests(PreservesTestCase): def test_eq_identity(self): self.assertTrue(eq(1, 1)) self.assertFalse(eq(1, 1.0)) diff --git a/implementations/python/tests/test_path.py b/implementations/python/tests/test_path.py index 2f8afa5..18c5279 100644 --- a/implementations/python/tests/test_path.py +++ b/implementations/python/tests/test_path.py @@ -1,30 +1,30 @@ -import unittest +from utils import PreservesTestCase from preserves import * from preserves.path import parse -class BasicPathTests(unittest.TestCase): +class BasicPathTests(PreservesTestCase): def test_identity(self): - self.assertEqual(parse('').exec(1), (1,)) - self.assertEqual(parse('').exec([]), ([],)) - self.assertEqual(parse('').exec(Record(Symbol('hi'), [])), (Record(Symbol('hi'), []),)) + self.assertPreservesEqual(parse('').exec(1), (1,)) + self.assertPreservesEqual(parse('').exec([]), ([],)) + self.assertPreservesEqual(parse('').exec(Record(Symbol('hi'), [])), (Record(Symbol('hi'), []),)) def test_children(self): - self.assertEqual(parse('/').exec([1, 2, 3]), (1, 2, 3)) - self.assertEqual(parse('/').exec([1, [2], 3]), (1, [2], 3)) - self.assertEqual(parse('/').exec(Record(Symbol('hi'), [1, [2], 3])), (1, [2], 3)) + self.assertPreservesEqual(parse('/').exec([1, 2, 3]), (1, 2, 3)) + self.assertPreservesEqual(parse('/').exec([1, [2], 3]), (1, [2], 3)) + self.assertPreservesEqual(parse('/').exec(Record(Symbol('hi'), [1, [2], 3])), (1, [2], 3)) def test_label(self): - self.assertEqual(parse('.^').exec([1, 2, 3]), ()) - self.assertEqual(parse('.^').exec([1, [2], 3]), ()) - self.assertEqual(parse('.^').exec(Record(Symbol('hi'), [1, [2], 3])), (Symbol('hi'),)) + self.assertPreservesEqual(parse('.^').exec([1, 2, 3]), ()) + self.assertPreservesEqual(parse('.^').exec([1, [2], 3]), ()) + self.assertPreservesEqual(parse('.^').exec(Record(Symbol('hi'), [1, [2], 3])), (Symbol('hi'),)) def test_count(self): - self.assertEqual(parse('').exec([ Record(Symbol('hi'), [1]), - Record(Symbol('no'), [2]), - Record(Symbol('hi'), [3]) ]), + self.assertPreservesEqual(parse('').exec([ Record(Symbol('hi'), [1]), + Record(Symbol('no'), [2]), + Record(Symbol('hi'), [3]) ]), (2,)) - self.assertEqual(parse('/ ').exec([ Record(Symbol('hi'), [1]), - Record(Symbol('no'), [2]), - Record(Symbol('hi'), [3]) ]), + self.assertPreservesEqual(parse('/ ').exec([ Record(Symbol('hi'), [1]), + Record(Symbol('no'), [2]), + Record(Symbol('hi'), [3]) ]), (1, 0, 1)) diff --git a/implementations/python/tests/test_preserves.py b/implementations/python/tests/test_preserves.py index 9dc2f9c..ecd2b6c 100644 --- a/implementations/python/tests/test_preserves.py +++ b/implementations/python/tests/test_preserves.py @@ -1,11 +1,12 @@ import numbers import os import sys -import unittest # Make `preserves` available for imports sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +from utils import PreservesTestCase + from preserves import * from preserves.compat import basestring_, ord_ from preserves.values import _unwrap @@ -49,33 +50,33 @@ def _e(v): def _R(k, *args): return Record(Symbol(k), args) -class BinaryCodecTests(unittest.TestCase): +class BinaryCodecTests(PreservesTestCase): def _roundtrip(self, forward, expected, back=None, nondeterministic=False): if back is None: back = forward - self.assertEqual(_d(_e(forward)), back) - self.assertEqual(_d(_e(back)), back) - self.assertEqual(_d(expected), back) + self.assertPreservesEqual(_d(_e(forward)), back) + self.assertPreservesEqual(_d(_e(back)), back) + self.assertPreservesEqual(_d(expected), back) if not nondeterministic: actual = _e(forward) - self.assertEqual(actual, expected, '%s != %s' % (_hex(actual), _hex(expected))) + self.assertPreservesEqual(actual, expected, '%s != %s' % (_hex(actual), _hex(expected))) def test_decode_varint(self): with self.assertRaises(DecodeError): Decoder(_buf()).varint() - self.assertEqual(Decoder(_buf(0)).varint(), 0) - self.assertEqual(Decoder(_buf(10)).varint(), 10) - self.assertEqual(Decoder(_buf(100)).varint(), 100) - self.assertEqual(Decoder(_buf(200, 1)).varint(), 200) - self.assertEqual(Decoder(_buf(0b10101100, 0b00000010)).varint(), 300) - self.assertEqual(Decoder(_buf(128, 148, 235, 220, 3)).varint(), 1000000000) + self.assertPreservesEqual(Decoder(_buf(0)).varint(), 0) + self.assertPreservesEqual(Decoder(_buf(10)).varint(), 10) + self.assertPreservesEqual(Decoder(_buf(100)).varint(), 100) + self.assertPreservesEqual(Decoder(_buf(200, 1)).varint(), 200) + self.assertPreservesEqual(Decoder(_buf(0b10101100, 0b00000010)).varint(), 300) + self.assertPreservesEqual(Decoder(_buf(128, 148, 235, 220, 3)).varint(), 1000000000) def test_encode_varint(self): - self.assertEqual(_varint(0), _buf(0)) - self.assertEqual(_varint(10), _buf(10)) - self.assertEqual(_varint(100), _buf(100)) - self.assertEqual(_varint(200), _buf(200, 1)) - self.assertEqual(_varint(300), _buf(0b10101100, 0b00000010)) - self.assertEqual(_varint(1000000000), _buf(128, 148, 235, 220, 3)) + self.assertPreservesEqual(_varint(0), _buf(0)) + self.assertPreservesEqual(_varint(10), _buf(10)) + self.assertPreservesEqual(_varint(100), _buf(100)) + self.assertPreservesEqual(_varint(200), _buf(200, 1)) + self.assertPreservesEqual(_varint(300), _buf(0b10101100, 0b00000010)) + self.assertPreservesEqual(_varint(1000000000), _buf(128, 148, 235, 220, 3)) def test_simple_seq(self): self._roundtrip([1,2,3,4], _buf(0xb5, 0x91, 0x92, 0x93, 0x94, 0x84), back=(1,2,3,4)) @@ -157,7 +158,7 @@ class BinaryCodecTests(unittest.TestCase): # python 3 bs = _e(d.items()) self.assertRegex(_hex(bs), r) - self.assertEqual(sorted(_d(bs)), [(u'a', 1), (u'b', 2), (u'c', 3)]) + self.assertPreservesEqual(sorted(_d(bs)), [(u'a', 1), (u'b', 2), (u'c', 3)]) def test_long_sequence(self): self._roundtrip((False,) * 14, _buf(0xb5, b'\x80' * 14, 0x84)) @@ -172,9 +173,9 @@ class BinaryCodecTests(unittest.TestCase): a1 = Embedded(A(1)) a2 = Embedded(A(1)) self.assertNotEqual(encode(a1, encode_embedded=id), encode(a2, encode_embedded=id)) - self.assertEqual(encode(a1, encode_embedded=id), encode(a1, encode_embedded=id)) - self.assertEqual(ord_(encode(a1, encode_embedded=id)[0]), 0x86) - self.assertEqual(ord_(encode(a2, encode_embedded=id)[0]), 0x86) + self.assertPreservesEqual(encode(a1, encode_embedded=id), encode(a1, encode_embedded=id)) + self.assertPreservesEqual(ord_(encode(a1, encode_embedded=id)[0]), 0x86) + self.assertPreservesEqual(ord_(encode(a2, encode_embedded=id)[0]), 0x86) def test_decode_embedded_absent(self): with self.assertRaises(DecodeError): @@ -185,15 +186,15 @@ class BinaryCodecTests(unittest.TestCase): def enc(p): objects.append(p) return len(objects) - 1 - self.assertEqual(encode([Embedded(object()), Embedded(object())], encode_embedded = enc), - b'\xb5\x86\x90\x86\x91\x84') + self.assertPreservesEqual(encode([Embedded(object()), Embedded(object())], encode_embedded = enc), + b'\xb5\x86\x90\x86\x91\x84') def test_decode_embedded(self): objects = [123, 234] def dec(v): return objects[v] - self.assertEqual(decode(b'\xb5\x86\x90\x86\x91\x84', decode_embedded = dec), - (Embedded(123), Embedded(234))) + self.assertPreservesEqual(decode(b'\xb5\x86\x90\x86\x91\x84', decode_embedded = dec), + (Embedded(123), Embedded(234))) def load_binary_samples(): with open(os.path.join(os.path.dirname(__file__), 'samples.bin'), 'rb') as f: @@ -203,16 +204,16 @@ def load_text_samples(): with open(os.path.join(os.path.dirname(__file__), 'samples.pr'), 'rt') as f: return Parser(f.read(), include_annotations=True, parse_embedded=lambda x: x).next() -class TextCodecTests(unittest.TestCase): +class TextCodecTests(PreservesTestCase): def test_samples_bin_eq_txt(self): b = load_binary_samples() t = load_text_samples() - self.assertEqual(b, t) + self.assertPreservesEqual(b, t) def test_txt_roundtrip(self): b = load_binary_samples() s = stringify(b, format_embedded=lambda x: x) - self.assertEqual(parse(s, include_annotations=True, parse_embedded=lambda x: x), b) + self.assertPreservesEqual(parse(s, include_annotations=True, parse_embedded=lambda x: x), b) def add_method(d, tName, fn): if hasattr(fn, 'func_name'): @@ -254,14 +255,14 @@ def install_test(d, variant, tName, binaryForm, annotatedTextForm): entry = get_expected_values(tName, textForm) forward = entry['forward'] back = entry['back'] - def test_match_expected(self): self.assertEqual(textForm, back) - def test_roundtrip(self): self.assertEqual(self.DS(self.E(textForm)), back) - def test_forward(self): self.assertEqual(self.DS(self.E(forward)), back) - def test_back(self): self.assertEqual(self.DS(binaryForm), back) - def test_back_ann(self): self.assertEqual(self.D(self.E(annotatedTextForm)), annotatedTextForm) - def test_encode(self): self.assertEqual(self.E(forward), binaryForm) - def test_encode_canonical(self): self.assertEqual(self.EC(annotatedTextForm), binaryForm) - def test_encode_ann(self): self.assertEqual(self.E(annotatedTextForm), binaryForm) + def test_match_expected(self): self.assertPreservesEqual(textForm, back) + def test_roundtrip(self): self.assertPreservesEqual(self.DS(self.E(textForm)), back) + def test_forward(self): self.assertPreservesEqual(self.DS(self.E(forward)), back) + def test_back(self): self.assertPreservesEqual(self.DS(binaryForm), back) + def test_back_ann(self): self.assertPreservesEqual(self.D(self.E(annotatedTextForm)), annotatedTextForm) + def test_encode(self): self.assertPreservesEqual(self.E(forward), binaryForm) + def test_encode_canonical(self): self.assertPreservesEqual(self.EC(annotatedTextForm), binaryForm) + def test_encode_ann(self): self.assertPreservesEqual(self.E(annotatedTextForm), binaryForm) add_method(d, tName, test_match_expected) add_method(d, tName, test_roundtrip) add_method(d, tName, test_forward) @@ -284,7 +285,7 @@ def install_exn_test(d, tName, bs, check_proc): self.fail('did not fail as expected') add_method(d, tName, test_exn) -class CommonTestSuite(unittest.TestCase): +class CommonTestSuite(PreservesTestCase): TestCases = Record.makeConstructor('TestCases', 'cases') samples = load_binary_samples() @@ -325,7 +326,7 @@ class CommonTestSuite(unittest.TestCase): def EC(self, v): return encode(v, encode_embedded=lambda x: x, canonicalize=True) -class RecordTests(unittest.TestCase): +class RecordTests(PreservesTestCase): def test_getters(self): T = Record.makeConstructor('t', 'x y z') T2 = Record.makeConstructor('t', 'x y z') @@ -334,8 +335,8 @@ class RecordTests(unittest.TestCase): self.assertTrue(T.isClassOf(t)) self.assertTrue(T2.isClassOf(t)) self.assertFalse(U.isClassOf(t)) - self.assertEqual(T._x(t), 1) - self.assertEqual(T2._y(t), 2) - self.assertEqual(T._z(t), 3) + self.assertPreservesEqual(T._x(t), 1) + self.assertPreservesEqual(T2._y(t), 2) + self.assertPreservesEqual(T._z(t), 3) with self.assertRaises(TypeError): U._x(t) diff --git a/implementations/python/tests/test_schema.py b/implementations/python/tests/test_schema.py index 69e8688..115da71 100644 --- a/implementations/python/tests/test_schema.py +++ b/implementations/python/tests/test_schema.py @@ -1,4 +1,4 @@ -import unittest +from utils import PreservesTestCase from preserves import * from preserves.schema import meta, Compiler @@ -8,7 +8,7 @@ def literal_schema(modname, s): c.load_schema((Symbol(modname),), preserve(s)) return c.root -class BasicSchemaTests(unittest.TestCase): +class BasicSchemaTests(PreservesTestCase): def test_dictionary_literal(self): m = literal_schema( 's', @@ -22,7 +22,7 @@ class BasicSchemaTests(unittest.TestCase): }> ''')) self.assertEqual(m.s.C.decode({'core': Symbol('true')}), m.s.C()) - self.assertEqual(preserve(m.s.C()), {'core': Symbol('true')}) + self.assertPreservesEqual(preserve(m.s.C()), {'core': Symbol('true')}) def test_alternation_of_dictionary_literal(self): m = literal_schema( @@ -40,6 +40,6 @@ class BasicSchemaTests(unittest.TestCase): }> ''')) self.assertEqual(m.s.C.decode({'core': Symbol('true')}), m.s.C.core()) - self.assertEqual(preserve(m.s.C.core()), {'core': Symbol('true')}) + self.assertPreservesEqual(preserve(m.s.C.core()), {'core': Symbol('true')}) self.assertEqual(m.s.C.decode({'notcore': Symbol('true')}), m.s.C.notcore()) - self.assertEqual(preserve(m.s.C.notcore()), {'notcore': Symbol('true')}) + self.assertPreservesEqual(preserve(m.s.C.notcore()), {'notcore': Symbol('true')}) diff --git a/implementations/python/tests/utils.py b/implementations/python/tests/utils.py new file mode 100644 index 0000000..2487c6f --- /dev/null +++ b/implementations/python/tests/utils.py @@ -0,0 +1,9 @@ +import unittest + +from preserves import cmp + +class PreservesTestCase(unittest.TestCase): + def assertPreservesEqual(self, a, b, msg=None): + if msg is None: + msg = 'Expected %s to be Preserves-equal to %s' % (a, b) + self.assertTrue(cmp(a, b) == 0, msg)