Update Python implementation; repair comparison routines
This commit is contained in:
parent
8ff1c9441c
commit
67613877ce
|
@ -72,7 +72,7 @@ class Decoder(BinaryCodec):
|
|||
tag = self.nextbyte()
|
||||
if tag == 0x80: return self.wrap(False)
|
||||
if tag == 0x81: return self.wrap(True)
|
||||
if tag == 0x82: return self.wrap(Float(struct.unpack('>f', self.nextbytes(4))[0]))
|
||||
if tag == 0x82: return self.wrap(Float.from_bytes(self.nextbytes(4)))
|
||||
if tag == 0x83: return self.wrap(struct.unpack('>d', self.nextbytes(8))[0])
|
||||
if tag == 0x84: raise DecodeError('Unexpected end-of-stream marker')
|
||||
if tag == 0x85:
|
||||
|
|
|
@ -2,7 +2,7 @@ import numbers
|
|||
from enum import Enum
|
||||
from functools import cmp_to_key
|
||||
|
||||
from .values import preserve, Float, Embedded, Record, Symbol
|
||||
from .values import preserve, Float, Embedded, Record, Symbol, cmp_floats, _unwrap
|
||||
from .compat import basestring_
|
||||
|
||||
class TypeNumber(Enum):
|
||||
|
@ -19,7 +19,7 @@ class TypeNumber(Enum):
|
|||
SET = 9
|
||||
DICTIONARY = 10
|
||||
|
||||
EMBEDDED = 10
|
||||
EMBEDDED = 11
|
||||
|
||||
def type_number(v):
|
||||
if hasattr(v, '__preserve__'):
|
||||
|
@ -84,12 +84,17 @@ def _item_key(item):
|
|||
return item[0]
|
||||
|
||||
def _eq(a, b):
|
||||
a = _unwrap(a)
|
||||
b = _unwrap(b)
|
||||
ta = type_number(a)
|
||||
tb = type_number(b)
|
||||
if ta != tb: return False
|
||||
|
||||
if ta == TypeNumber.DOUBLE:
|
||||
return cmp_floats(a, b) == 0
|
||||
|
||||
if ta == TypeNumber.EMBEDDED:
|
||||
return ta.embeddedValue == tb.embeddedValue
|
||||
return _eq(a.embeddedValue, b.embeddedValue)
|
||||
|
||||
if ta == TypeNumber.RECORD:
|
||||
return _eq(a.key, b.key) and _eq_sequences(a.fields, b.fields)
|
||||
|
@ -118,13 +123,18 @@ def _cmp_sequences(aa, bb):
|
|||
return len(aa) - len(bb)
|
||||
|
||||
def _cmp(a, b):
|
||||
a = _unwrap(a)
|
||||
b = _unwrap(b)
|
||||
ta = type_number(a)
|
||||
tb = type_number(b)
|
||||
if ta.value < tb.value: return -1
|
||||
if tb.value < ta.value: return 1
|
||||
|
||||
if ta == TypeNumber.DOUBLE:
|
||||
return cmp_floats(a, b)
|
||||
|
||||
if ta == TypeNumber.EMBEDDED:
|
||||
return _simplecmp(ta.embeddedValue, tb.embeddedValue)
|
||||
return _cmp(a.embeddedValue, b.embeddedValue)
|
||||
|
||||
if ta == TypeNumber.RECORD:
|
||||
v = _cmp(a.key, b.key)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import numbers
|
||||
import struct
|
||||
import base64
|
||||
import math
|
||||
|
||||
from .values import *
|
||||
from .error import *
|
||||
|
@ -9,6 +10,8 @@ from .binary import Decoder
|
|||
|
||||
class TextCodec(object): pass
|
||||
|
||||
NUMBER_RE = re.compile(r'^([-+]?\d+)(((\.\d+([eE][-+]?\d+)?)|([eE][-+]?\d+))([fF]?))?$')
|
||||
|
||||
class Parser(TextCodec):
|
||||
def __init__(self, input_buffer=u'', include_annotations=False, parse_embedded=lambda x: x):
|
||||
super(Parser, self).__init__()
|
||||
|
@ -66,50 +69,6 @@ class Parser(TextCodec):
|
|||
return self.wrap(u''.join(s))
|
||||
s.append(c)
|
||||
|
||||
def read_intpart(self, acc, c):
|
||||
if c == '0':
|
||||
acc.append(c)
|
||||
else:
|
||||
self.read_digit1(acc, c)
|
||||
return self.read_fracexp(acc)
|
||||
|
||||
def read_fracexp(self, acc):
|
||||
is_float = False
|
||||
if self.peek() == '.':
|
||||
is_float = True
|
||||
acc.append(self.nextchar())
|
||||
self.read_digit1(acc, self.nextchar())
|
||||
if self.peek() in 'eE':
|
||||
acc.append(self.nextchar())
|
||||
return self.read_sign_and_exp(acc)
|
||||
else:
|
||||
return self.finish_number(acc, is_float)
|
||||
|
||||
def read_sign_and_exp(self, acc):
|
||||
if self.peek() in '+-':
|
||||
acc.append(self.nextchar())
|
||||
self.read_digit1(acc, self.nextchar())
|
||||
return self.finish_number(acc, True)
|
||||
|
||||
def finish_number(self, acc, is_float):
|
||||
if is_float:
|
||||
if self.peek() in 'fF':
|
||||
self.skip()
|
||||
return Float(float(u''.join(acc)))
|
||||
else:
|
||||
return float(u''.join(acc))
|
||||
else:
|
||||
return int(u''.join(acc))
|
||||
|
||||
def read_digit1(self, acc, c):
|
||||
if not c.isdigit():
|
||||
raise DecodeError('Incomplete number')
|
||||
acc.append(c)
|
||||
while not self._atend():
|
||||
if not self.peek().isdigit():
|
||||
break
|
||||
acc.append(self.nextchar())
|
||||
|
||||
def read_stringlike(self, terminator, hexescape, hexescaper):
|
||||
acc = []
|
||||
while True:
|
||||
|
@ -186,6 +145,16 @@ class Parser(TextCodec):
|
|||
if c == '=': continue
|
||||
acc.append(c)
|
||||
|
||||
def read_hex_float(self, bytecount):
|
||||
if self.nextchar() != '"':
|
||||
raise DecodeError('Missing open-double-quote in hex-encoded floating-point number')
|
||||
bs = self.read_hex_binary()
|
||||
if len(bs) != bytecount:
|
||||
raise DecodeError('Incorrect number of bytes in hex-encoded floating-point number')
|
||||
if bytecount == 4: return Float.from_bytes(bs)
|
||||
if bytecount == 8: return struct.unpack('>d', bs)[0]
|
||||
raise DecodeError('Unsupported byte count in hex-encoded floating-point number')
|
||||
|
||||
def upto(self, delimiter):
|
||||
vs = []
|
||||
while True:
|
||||
|
@ -208,14 +177,24 @@ class Parser(TextCodec):
|
|||
raise DecodeError('Missing expected key/value separator')
|
||||
acc.append(self.next())
|
||||
|
||||
def read_raw_symbol(self, acc):
|
||||
def read_raw_symbol_or_number(self, acc):
|
||||
while not self._atend():
|
||||
c = self.peek()
|
||||
if c.isspace() or c in '(){}[]<>";,@#:|':
|
||||
break
|
||||
self.skip()
|
||||
acc.append(c)
|
||||
return Symbol(u''.join(acc))
|
||||
acc = u''.join(acc)
|
||||
m = NUMBER_RE.match(acc)
|
||||
if m:
|
||||
if m[2] is None:
|
||||
return int(m[1])
|
||||
elif m[7] == '':
|
||||
return float(m[1] + m[3])
|
||||
else:
|
||||
return Float(float(m[1] + m[3]))
|
||||
else:
|
||||
return Symbol(acc)
|
||||
|
||||
def wrap(self, v):
|
||||
return Annotated(v) if self.include_annotations else v
|
||||
|
@ -223,12 +202,6 @@ class Parser(TextCodec):
|
|||
def next(self):
|
||||
self.skip_whitespace()
|
||||
c = self.peek()
|
||||
if c == '-':
|
||||
self.skip()
|
||||
return self.wrap(self.read_intpart(['-'], self.nextchar()))
|
||||
if c.isdigit():
|
||||
self.skip()
|
||||
return self.wrap(self.read_intpart([], c))
|
||||
if c == '"':
|
||||
self.skip()
|
||||
return self.wrap(self.read_string('"'))
|
||||
|
@ -251,9 +224,11 @@ class Parser(TextCodec):
|
|||
if c == '{': return self.wrap(frozenset(self.upto('}')))
|
||||
if c == '"': return self.wrap(self.read_literal_binary())
|
||||
if c == 'x':
|
||||
if self.nextchar() != '"':
|
||||
raise DecodeError('Expected open-quote at start of hex ByteString')
|
||||
return self.wrap(self.read_hex_binary())
|
||||
c = self.nextchar()
|
||||
if c == '"': return self.wrap(self.read_hex_binary())
|
||||
if c == 'f': return self.wrap(self.read_hex_float(4))
|
||||
if c == 'd': return self.wrap(self.read_hex_float(8))
|
||||
raise DecodeError('Invalid #x syntax')
|
||||
if c == '[': return self.wrap(self.read_base64_binary())
|
||||
if c == '=':
|
||||
old_ann = self.include_annotations
|
||||
|
@ -286,7 +261,7 @@ class Parser(TextCodec):
|
|||
if c in '>]}':
|
||||
raise DecodeError('Unexpected ' + c)
|
||||
self.skip()
|
||||
return self.wrap(self.read_raw_symbol([c]))
|
||||
return self.wrap(self.read_raw_symbol_or_number([c]))
|
||||
|
||||
def try_next(self):
|
||||
start = self.index
|
||||
|
@ -385,7 +360,10 @@ class Formatter(TextCodec):
|
|||
elif v is True:
|
||||
self.chunks.append('#t')
|
||||
elif isinstance(v, float):
|
||||
self.chunks.append(repr(v))
|
||||
if math.isnan(v) or math.isinf(v):
|
||||
self.chunks.append('#xd"' + struct.pack('>d', v).hex() + '"')
|
||||
else:
|
||||
self.chunks.append(repr(v))
|
||||
elif isinstance(v, numbers.Number):
|
||||
self.chunks.append('%d' % (v,))
|
||||
elif isinstance(v, bytes):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import re
|
||||
import sys
|
||||
import struct
|
||||
import math
|
||||
|
||||
from .error import DecodeError
|
||||
|
||||
|
@ -9,6 +10,16 @@ def preserve(v):
|
|||
v = v.__preserve__()
|
||||
return v
|
||||
|
||||
def float_to_int(v):
|
||||
return struct.unpack('>Q', struct.pack('>d', v))[0]
|
||||
|
||||
def cmp_floats(a, b):
|
||||
a = float_to_int(a)
|
||||
b = float_to_int(b)
|
||||
if a & 0x8000000000000000: a = a ^ 0x7fffffffffffffff
|
||||
if b & 0x8000000000000000: b = b ^ 0x7fffffffffffffff
|
||||
return a - b
|
||||
|
||||
class Float(object):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
@ -16,7 +27,12 @@ class Float(object):
|
|||
def __eq__(self, other):
|
||||
other = _unwrap(other)
|
||||
if other.__class__ is self.__class__:
|
||||
return self.value == other.value
|
||||
return cmp_floats(self.value, other.value) == 0
|
||||
|
||||
def __lt__(self, other):
|
||||
other = _unwrap(other)
|
||||
if other.__class__ is self.__class__:
|
||||
return cmp_floats(self.value, other.value) < 0
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
@ -27,15 +43,41 @@ class Float(object):
|
|||
def __repr__(self):
|
||||
return 'Float(' + repr(self.value) + ')'
|
||||
|
||||
def _to_bytes(self):
|
||||
if math.isnan(self.value) or math.isinf(self.value):
|
||||
dbs = struct.pack('>d', self.value)
|
||||
vd = struct.unpack('>Q', dbs)[0]
|
||||
sign = vd >> 63
|
||||
payload = (vd >> 29) & 0x007fffff
|
||||
vf = (sign << 31) | 0x7f800000 | payload
|
||||
return struct.pack('>I', vf)
|
||||
else:
|
||||
return struct.pack('>f', self.value)
|
||||
|
||||
def __preserve_write_binary__(self, encoder):
|
||||
encoder.buffer.append(0x82)
|
||||
encoder.buffer.extend(struct.pack('>f', self.value))
|
||||
encoder.buffer.extend(self._to_bytes())
|
||||
|
||||
def __preserve_write_text__(self, formatter):
|
||||
formatter.chunks.append(repr(self.value) + 'f')
|
||||
if math.isnan(self.value) or math.isinf(self.value):
|
||||
formatter.chunks.append('#xf"' + self._to_bytes().hex() + '"')
|
||||
else:
|
||||
formatter.chunks.append(repr(self.value) + 'f')
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(bs):
|
||||
vf = struct.unpack('>I', bs)[0]
|
||||
if (vf & 0x7f800000) == 0x7f800000:
|
||||
# NaN or inf. Preserve quiet/signalling bit by manually expanding to double-precision.
|
||||
sign = vf >> 31
|
||||
payload = vf & 0x007fffff
|
||||
dbs = struct.pack('>Q', (sign << 63) | 0x7ff0000000000000 | (payload << 29))
|
||||
return Float(struct.unpack('>d', dbs)[0])
|
||||
else:
|
||||
return Float(struct.unpack('>f', bs)[0])
|
||||
|
||||
# FIXME: This regular expression is conservatively correct, but Anglo-chauvinistic.
|
||||
RAW_SYMBOL_RE = re.compile(r'^[a-zA-Z~!$%^&*?_=+/.][-a-zA-Z~!$%^&*?_=+/.0-9]*$')
|
||||
RAW_SYMBOL_RE = re.compile(r'^[-a-zA-Z0-9~!$%^&*?_=+/.]+$')
|
||||
|
||||
class Symbol(object):
|
||||
def __init__(self, name):
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import unittest
|
||||
from utils import PreservesTestCase
|
||||
|
||||
from preserves import *
|
||||
from preserves.compare import *
|
||||
|
||||
class BasicCompareTests(unittest.TestCase):
|
||||
class BasicCompareTests(PreservesTestCase):
|
||||
def test_eq_identity(self):
|
||||
self.assertTrue(eq(1, 1))
|
||||
self.assertFalse(eq(1, 1.0))
|
||||
|
|
|
@ -1,30 +1,30 @@
|
|||
import unittest
|
||||
from utils import PreservesTestCase
|
||||
|
||||
from preserves import *
|
||||
from preserves.path import parse
|
||||
|
||||
class BasicPathTests(unittest.TestCase):
|
||||
class BasicPathTests(PreservesTestCase):
|
||||
def test_identity(self):
|
||||
self.assertEqual(parse('').exec(1), (1,))
|
||||
self.assertEqual(parse('').exec([]), ([],))
|
||||
self.assertEqual(parse('').exec(Record(Symbol('hi'), [])), (Record(Symbol('hi'), []),))
|
||||
self.assertPreservesEqual(parse('').exec(1), (1,))
|
||||
self.assertPreservesEqual(parse('').exec([]), ([],))
|
||||
self.assertPreservesEqual(parse('').exec(Record(Symbol('hi'), [])), (Record(Symbol('hi'), []),))
|
||||
|
||||
def test_children(self):
|
||||
self.assertEqual(parse('/').exec([1, 2, 3]), (1, 2, 3))
|
||||
self.assertEqual(parse('/').exec([1, [2], 3]), (1, [2], 3))
|
||||
self.assertEqual(parse('/').exec(Record(Symbol('hi'), [1, [2], 3])), (1, [2], 3))
|
||||
self.assertPreservesEqual(parse('/').exec([1, 2, 3]), (1, 2, 3))
|
||||
self.assertPreservesEqual(parse('/').exec([1, [2], 3]), (1, [2], 3))
|
||||
self.assertPreservesEqual(parse('/').exec(Record(Symbol('hi'), [1, [2], 3])), (1, [2], 3))
|
||||
|
||||
def test_label(self):
|
||||
self.assertEqual(parse('.^').exec([1, 2, 3]), ())
|
||||
self.assertEqual(parse('.^').exec([1, [2], 3]), ())
|
||||
self.assertEqual(parse('.^').exec(Record(Symbol('hi'), [1, [2], 3])), (Symbol('hi'),))
|
||||
self.assertPreservesEqual(parse('.^').exec([1, 2, 3]), ())
|
||||
self.assertPreservesEqual(parse('.^').exec([1, [2], 3]), ())
|
||||
self.assertPreservesEqual(parse('.^').exec(Record(Symbol('hi'), [1, [2], 3])), (Symbol('hi'),))
|
||||
|
||||
def test_count(self):
|
||||
self.assertEqual(parse('<count / ^ hi>').exec([ Record(Symbol('hi'), [1]),
|
||||
Record(Symbol('no'), [2]),
|
||||
Record(Symbol('hi'), [3]) ]),
|
||||
self.assertPreservesEqual(parse('<count / ^ hi>').exec([ Record(Symbol('hi'), [1]),
|
||||
Record(Symbol('no'), [2]),
|
||||
Record(Symbol('hi'), [3]) ]),
|
||||
(2,))
|
||||
self.assertEqual(parse('/ <count ^ hi>').exec([ Record(Symbol('hi'), [1]),
|
||||
Record(Symbol('no'), [2]),
|
||||
Record(Symbol('hi'), [3]) ]),
|
||||
self.assertPreservesEqual(parse('/ <count ^ hi>').exec([ Record(Symbol('hi'), [1]),
|
||||
Record(Symbol('no'), [2]),
|
||||
Record(Symbol('hi'), [3]) ]),
|
||||
(1, 0, 1))
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import numbers
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
# Make `preserves` available for imports
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from utils import PreservesTestCase
|
||||
|
||||
from preserves import *
|
||||
from preserves.compat import basestring_, ord_
|
||||
from preserves.values import _unwrap
|
||||
|
@ -49,33 +50,33 @@ def _e(v):
|
|||
def _R(k, *args):
|
||||
return Record(Symbol(k), args)
|
||||
|
||||
class BinaryCodecTests(unittest.TestCase):
|
||||
class BinaryCodecTests(PreservesTestCase):
|
||||
def _roundtrip(self, forward, expected, back=None, nondeterministic=False):
|
||||
if back is None: back = forward
|
||||
self.assertEqual(_d(_e(forward)), back)
|
||||
self.assertEqual(_d(_e(back)), back)
|
||||
self.assertEqual(_d(expected), back)
|
||||
self.assertPreservesEqual(_d(_e(forward)), back)
|
||||
self.assertPreservesEqual(_d(_e(back)), back)
|
||||
self.assertPreservesEqual(_d(expected), back)
|
||||
if not nondeterministic:
|
||||
actual = _e(forward)
|
||||
self.assertEqual(actual, expected, '%s != %s' % (_hex(actual), _hex(expected)))
|
||||
self.assertPreservesEqual(actual, expected, '%s != %s' % (_hex(actual), _hex(expected)))
|
||||
|
||||
def test_decode_varint(self):
|
||||
with self.assertRaises(DecodeError):
|
||||
Decoder(_buf()).varint()
|
||||
self.assertEqual(Decoder(_buf(0)).varint(), 0)
|
||||
self.assertEqual(Decoder(_buf(10)).varint(), 10)
|
||||
self.assertEqual(Decoder(_buf(100)).varint(), 100)
|
||||
self.assertEqual(Decoder(_buf(200, 1)).varint(), 200)
|
||||
self.assertEqual(Decoder(_buf(0b10101100, 0b00000010)).varint(), 300)
|
||||
self.assertEqual(Decoder(_buf(128, 148, 235, 220, 3)).varint(), 1000000000)
|
||||
self.assertPreservesEqual(Decoder(_buf(0)).varint(), 0)
|
||||
self.assertPreservesEqual(Decoder(_buf(10)).varint(), 10)
|
||||
self.assertPreservesEqual(Decoder(_buf(100)).varint(), 100)
|
||||
self.assertPreservesEqual(Decoder(_buf(200, 1)).varint(), 200)
|
||||
self.assertPreservesEqual(Decoder(_buf(0b10101100, 0b00000010)).varint(), 300)
|
||||
self.assertPreservesEqual(Decoder(_buf(128, 148, 235, 220, 3)).varint(), 1000000000)
|
||||
|
||||
def test_encode_varint(self):
|
||||
self.assertEqual(_varint(0), _buf(0))
|
||||
self.assertEqual(_varint(10), _buf(10))
|
||||
self.assertEqual(_varint(100), _buf(100))
|
||||
self.assertEqual(_varint(200), _buf(200, 1))
|
||||
self.assertEqual(_varint(300), _buf(0b10101100, 0b00000010))
|
||||
self.assertEqual(_varint(1000000000), _buf(128, 148, 235, 220, 3))
|
||||
self.assertPreservesEqual(_varint(0), _buf(0))
|
||||
self.assertPreservesEqual(_varint(10), _buf(10))
|
||||
self.assertPreservesEqual(_varint(100), _buf(100))
|
||||
self.assertPreservesEqual(_varint(200), _buf(200, 1))
|
||||
self.assertPreservesEqual(_varint(300), _buf(0b10101100, 0b00000010))
|
||||
self.assertPreservesEqual(_varint(1000000000), _buf(128, 148, 235, 220, 3))
|
||||
|
||||
def test_simple_seq(self):
|
||||
self._roundtrip([1,2,3,4], _buf(0xb5, 0x91, 0x92, 0x93, 0x94, 0x84), back=(1,2,3,4))
|
||||
|
@ -157,7 +158,7 @@ class BinaryCodecTests(unittest.TestCase):
|
|||
# python 3
|
||||
bs = _e(d.items())
|
||||
self.assertRegex(_hex(bs), r)
|
||||
self.assertEqual(sorted(_d(bs)), [(u'a', 1), (u'b', 2), (u'c', 3)])
|
||||
self.assertPreservesEqual(sorted(_d(bs)), [(u'a', 1), (u'b', 2), (u'c', 3)])
|
||||
|
||||
def test_long_sequence(self):
|
||||
self._roundtrip((False,) * 14, _buf(0xb5, b'\x80' * 14, 0x84))
|
||||
|
@ -172,9 +173,9 @@ class BinaryCodecTests(unittest.TestCase):
|
|||
a1 = Embedded(A(1))
|
||||
a2 = Embedded(A(1))
|
||||
self.assertNotEqual(encode(a1, encode_embedded=id), encode(a2, encode_embedded=id))
|
||||
self.assertEqual(encode(a1, encode_embedded=id), encode(a1, encode_embedded=id))
|
||||
self.assertEqual(ord_(encode(a1, encode_embedded=id)[0]), 0x86)
|
||||
self.assertEqual(ord_(encode(a2, encode_embedded=id)[0]), 0x86)
|
||||
self.assertPreservesEqual(encode(a1, encode_embedded=id), encode(a1, encode_embedded=id))
|
||||
self.assertPreservesEqual(ord_(encode(a1, encode_embedded=id)[0]), 0x86)
|
||||
self.assertPreservesEqual(ord_(encode(a2, encode_embedded=id)[0]), 0x86)
|
||||
|
||||
def test_decode_embedded_absent(self):
|
||||
with self.assertRaises(DecodeError):
|
||||
|
@ -185,15 +186,15 @@ class BinaryCodecTests(unittest.TestCase):
|
|||
def enc(p):
|
||||
objects.append(p)
|
||||
return len(objects) - 1
|
||||
self.assertEqual(encode([Embedded(object()), Embedded(object())], encode_embedded = enc),
|
||||
b'\xb5\x86\x90\x86\x91\x84')
|
||||
self.assertPreservesEqual(encode([Embedded(object()), Embedded(object())], encode_embedded = enc),
|
||||
b'\xb5\x86\x90\x86\x91\x84')
|
||||
|
||||
def test_decode_embedded(self):
|
||||
objects = [123, 234]
|
||||
def dec(v):
|
||||
return objects[v]
|
||||
self.assertEqual(decode(b'\xb5\x86\x90\x86\x91\x84', decode_embedded = dec),
|
||||
(Embedded(123), Embedded(234)))
|
||||
self.assertPreservesEqual(decode(b'\xb5\x86\x90\x86\x91\x84', decode_embedded = dec),
|
||||
(Embedded(123), Embedded(234)))
|
||||
|
||||
def load_binary_samples():
|
||||
with open(os.path.join(os.path.dirname(__file__), 'samples.bin'), 'rb') as f:
|
||||
|
@ -203,16 +204,16 @@ def load_text_samples():
|
|||
with open(os.path.join(os.path.dirname(__file__), 'samples.pr'), 'rt') as f:
|
||||
return Parser(f.read(), include_annotations=True, parse_embedded=lambda x: x).next()
|
||||
|
||||
class TextCodecTests(unittest.TestCase):
|
||||
class TextCodecTests(PreservesTestCase):
|
||||
def test_samples_bin_eq_txt(self):
|
||||
b = load_binary_samples()
|
||||
t = load_text_samples()
|
||||
self.assertEqual(b, t)
|
||||
self.assertPreservesEqual(b, t)
|
||||
|
||||
def test_txt_roundtrip(self):
|
||||
b = load_binary_samples()
|
||||
s = stringify(b, format_embedded=lambda x: x)
|
||||
self.assertEqual(parse(s, include_annotations=True, parse_embedded=lambda x: x), b)
|
||||
self.assertPreservesEqual(parse(s, include_annotations=True, parse_embedded=lambda x: x), b)
|
||||
|
||||
def add_method(d, tName, fn):
|
||||
if hasattr(fn, 'func_name'):
|
||||
|
@ -254,14 +255,14 @@ def install_test(d, variant, tName, binaryForm, annotatedTextForm):
|
|||
entry = get_expected_values(tName, textForm)
|
||||
forward = entry['forward']
|
||||
back = entry['back']
|
||||
def test_match_expected(self): self.assertEqual(textForm, back)
|
||||
def test_roundtrip(self): self.assertEqual(self.DS(self.E(textForm)), back)
|
||||
def test_forward(self): self.assertEqual(self.DS(self.E(forward)), back)
|
||||
def test_back(self): self.assertEqual(self.DS(binaryForm), back)
|
||||
def test_back_ann(self): self.assertEqual(self.D(self.E(annotatedTextForm)), annotatedTextForm)
|
||||
def test_encode(self): self.assertEqual(self.E(forward), binaryForm)
|
||||
def test_encode_canonical(self): self.assertEqual(self.EC(annotatedTextForm), binaryForm)
|
||||
def test_encode_ann(self): self.assertEqual(self.E(annotatedTextForm), binaryForm)
|
||||
def test_match_expected(self): self.assertPreservesEqual(textForm, back)
|
||||
def test_roundtrip(self): self.assertPreservesEqual(self.DS(self.E(textForm)), back)
|
||||
def test_forward(self): self.assertPreservesEqual(self.DS(self.E(forward)), back)
|
||||
def test_back(self): self.assertPreservesEqual(self.DS(binaryForm), back)
|
||||
def test_back_ann(self): self.assertPreservesEqual(self.D(self.E(annotatedTextForm)), annotatedTextForm)
|
||||
def test_encode(self): self.assertPreservesEqual(self.E(forward), binaryForm)
|
||||
def test_encode_canonical(self): self.assertPreservesEqual(self.EC(annotatedTextForm), binaryForm)
|
||||
def test_encode_ann(self): self.assertPreservesEqual(self.E(annotatedTextForm), binaryForm)
|
||||
add_method(d, tName, test_match_expected)
|
||||
add_method(d, tName, test_roundtrip)
|
||||
add_method(d, tName, test_forward)
|
||||
|
@ -284,7 +285,7 @@ def install_exn_test(d, tName, bs, check_proc):
|
|||
self.fail('did not fail as expected')
|
||||
add_method(d, tName, test_exn)
|
||||
|
||||
class CommonTestSuite(unittest.TestCase):
|
||||
class CommonTestSuite(PreservesTestCase):
|
||||
TestCases = Record.makeConstructor('TestCases', 'cases')
|
||||
|
||||
samples = load_binary_samples()
|
||||
|
@ -325,7 +326,7 @@ class CommonTestSuite(unittest.TestCase):
|
|||
def EC(self, v):
|
||||
return encode(v, encode_embedded=lambda x: x, canonicalize=True)
|
||||
|
||||
class RecordTests(unittest.TestCase):
|
||||
class RecordTests(PreservesTestCase):
|
||||
def test_getters(self):
|
||||
T = Record.makeConstructor('t', 'x y z')
|
||||
T2 = Record.makeConstructor('t', 'x y z')
|
||||
|
@ -334,8 +335,8 @@ class RecordTests(unittest.TestCase):
|
|||
self.assertTrue(T.isClassOf(t))
|
||||
self.assertTrue(T2.isClassOf(t))
|
||||
self.assertFalse(U.isClassOf(t))
|
||||
self.assertEqual(T._x(t), 1)
|
||||
self.assertEqual(T2._y(t), 2)
|
||||
self.assertEqual(T._z(t), 3)
|
||||
self.assertPreservesEqual(T._x(t), 1)
|
||||
self.assertPreservesEqual(T2._y(t), 2)
|
||||
self.assertPreservesEqual(T._z(t), 3)
|
||||
with self.assertRaises(TypeError):
|
||||
U._x(t)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import unittest
|
||||
from utils import PreservesTestCase
|
||||
|
||||
from preserves import *
|
||||
from preserves.schema import meta, Compiler
|
||||
|
@ -8,7 +8,7 @@ def literal_schema(modname, s):
|
|||
c.load_schema((Symbol(modname),), preserve(s))
|
||||
return c.root
|
||||
|
||||
class BasicSchemaTests(unittest.TestCase):
|
||||
class BasicSchemaTests(PreservesTestCase):
|
||||
def test_dictionary_literal(self):
|
||||
m = literal_schema(
|
||||
's',
|
||||
|
@ -22,7 +22,7 @@ class BasicSchemaTests(unittest.TestCase):
|
|||
}>
|
||||
'''))
|
||||
self.assertEqual(m.s.C.decode({'core': Symbol('true')}), m.s.C())
|
||||
self.assertEqual(preserve(m.s.C()), {'core': Symbol('true')})
|
||||
self.assertPreservesEqual(preserve(m.s.C()), {'core': Symbol('true')})
|
||||
|
||||
def test_alternation_of_dictionary_literal(self):
|
||||
m = literal_schema(
|
||||
|
@ -40,6 +40,6 @@ class BasicSchemaTests(unittest.TestCase):
|
|||
}>
|
||||
'''))
|
||||
self.assertEqual(m.s.C.decode({'core': Symbol('true')}), m.s.C.core())
|
||||
self.assertEqual(preserve(m.s.C.core()), {'core': Symbol('true')})
|
||||
self.assertPreservesEqual(preserve(m.s.C.core()), {'core': Symbol('true')})
|
||||
self.assertEqual(m.s.C.decode({'notcore': Symbol('true')}), m.s.C.notcore())
|
||||
self.assertEqual(preserve(m.s.C.notcore()), {'notcore': Symbol('true')})
|
||||
self.assertPreservesEqual(preserve(m.s.C.notcore()), {'notcore': Symbol('true')})
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
import unittest
|
||||
|
||||
from preserves import cmp
|
||||
|
||||
class PreservesTestCase(unittest.TestCase):
|
||||
def assertPreservesEqual(self, a, b, msg=None):
|
||||
if msg is None:
|
||||
msg = 'Expected %s to be Preserves-equal to %s' % (a, b)
|
||||
self.assertTrue(cmp(a, b) == 0, msg)
|
Loading…
Reference in New Issue