diff --git a/implementations/python/preserves/__init__.py b/implementations/python/preserves/__init__.py index 3ec1b12..9e2b391 100644 --- a/implementations/python/preserves/__init__.py +++ b/implementations/python/preserves/__init__.py @@ -1,6 +1,8 @@ from .values import Float, Symbol, Record, ImmutableDict, Embedded, preserve from .values import Annotated, is_annotated, strip_annotations, annotate +from .compare import cmp + from .error import DecodeError, EncodeError, ShortPacket from .binary import Decoder, Encoder, decode, decode_with_annotations, encode, canonicalize @@ -8,7 +10,7 @@ from .text import Parser, Formatter, parse, parse_with_annotations, stringify from .merge import merge -from . import fold +from . import fold, compare loads = parse dumps = stringify diff --git a/implementations/python/preserves/compare.py b/implementations/python/preserves/compare.py new file mode 100644 index 0000000..07aed24 --- /dev/null +++ b/implementations/python/preserves/compare.py @@ -0,0 +1,143 @@ +import numbers +from enum import Enum +from functools import cmp_to_key + +from .values import preserve, Float, Embedded, Record, Symbol +from .compat import basestring_ + +class TypeNumber(Enum): + BOOL = 0 + FLOAT = 1 + DOUBLE = 2 + SIGNED_INTEGER = 3 + STRING = 4 + BYTE_STRING = 5 + SYMBOL = 6 + + RECORD = 7 + SEQUENCE = 8 + SET = 9 + DICTIONARY = 10 + + EMBEDDED = 10 + +def type_number(v): + if hasattr(v, '__preserve__'): + raise ValueError('type_number expects Preserves value; use preserve()') + + if isinstance(v, bool): return TypeNumber.BOOL + if isinstance(v, Float): return TypeNumber.FLOAT + if isinstance(v, float): return TypeNumber.DOUBLE + if isinstance(v, numbers.Number): return TypeNumber.SIGNED_INTEGER + if isinstance(v, basestring_): return TypeNumber.STRING + if isinstance(v, bytes): return TypeNumber.BYTE_STRING + if isinstance(v, Symbol): return TypeNumber.SYMBOL + + if isinstance(v, Record): return TypeNumber.RECORD + if isinstance(v, list) or isinstance(v, tuple): return TypeNumber.SEQUENCE + if isinstance(v, set) or isinstance(v, frozenset): return TypeNumber.SET + if isinstance(v, dict): return TypeNumber.DICTIONARY + + if isinstance(v, Embedded): return TypeNumber.EMBEDDED + + try: + i = iter(v) + except TypeError: + i = None + if i is None: + raise ValueError('Invalid Preserves value in type_number') + else: + return TypeNumber.SEQUENCE + +def cmp(a, b): + return _cmp(preserve(a), preserve(b)) + +def lt(a, b): + return cmp(a, b) < 0 + +def le(a, b): + return cmp(a, b) <= 0 + +def eq(a, b): + return _eq(preserve(a), preserve(b)) + +key = cmp_to_key(cmp) +_key = key + +_sorted = sorted +def sorted(vs, /, *, key=lambda x: x, reverse=False): + return _sorted(vs, key=lambda x: _key(key(x)), reverse=reverse) + +def sorted_items(d): + return sorted(d.items(), key=_item_key) + +def _eq_sequences(aa, bb): + aa = list(aa) + bb = list(bb) + n = len(aa) + if len(bb) != n: return False + for i in range(n): + if not _eq(aa[i], bb[i]): return False + return True + +def _item_key(item): + return item[0] + +def _eq(a, b): + ta = type_number(a) + tb = type_number(b) + if ta != tb: return False + + if ta == TypeNumber.EMBEDDED: + return ta.embeddedValue == tb.embeddedValue + + if ta == TypeNumber.RECORD: + return _eq(a.key, b.key) and _eq_sequences(a.fields, b.fields) + + if ta == TypeNumber.SEQUENCE: + return _eq_sequences(a, b) + + if ta == TypeNumber.SET: + return _eq_sequences(sorted(a), sorted(b)) + + if ta == TypeNumber.DICTIONARY: + return _eq_sequences(sorted_items(a), sorted_items(b)) + + return a == b + +def _simplecmp(a, b): + return (a > b) - (a < b) + +def _cmp_sequences(aa, bb): + aa = list(aa) + bb = list(bb) + n = min(len(aa), len(bb)) + for i in range(n): + v = _cmp(aa[i], bb[i]) + if v != 0: return v + return len(aa) - len(bb) + +def _cmp(a, b): + ta = type_number(a) + tb = type_number(b) + if ta.value < tb.value: return -1 + if tb.value < ta.value: return 1 + + if ta == TypeNumber.EMBEDDED: + return _simplecmp(ta.embeddedValue, tb.embeddedValue) + + if ta == TypeNumber.RECORD: + v = _cmp(a.key, b.key) + if v != 0: return v + return _cmp_sequences(a.fields, b.fields) + + if ta == TypeNumber.SEQUENCE: + return _cmp_sequences(a, b) + + if ta == TypeNumber.SET: + return _cmp_sequences(sorted(a), sorted(b)) + + if ta == TypeNumber.DICTIONARY: + return _cmp_sequences(sorted_items(a), sorted_items(b)) + + return _simplecmp(a, b) diff --git a/implementations/python/preserves/values.py b/implementations/python/preserves/values.py index 67be60d..b96bba3 100644 --- a/implementations/python/preserves/values.py +++ b/implementations/python/preserves/values.py @@ -48,6 +48,18 @@ class Symbol(object): def __ne__(self, other): return not self.__eq__(other) + def __lt__(self, other): + return self.name < other.name + + def __le__(self, other): + return self.name <= other.name + + def __gt__(self, other): + return self.name > other.name + + def __ge__(self, other): + return self.name >= other.name + def __hash__(self): return hash(self.name) diff --git a/implementations/python/tests/test_compare.py b/implementations/python/tests/test_compare.py new file mode 100644 index 0000000..9bcf04e --- /dev/null +++ b/implementations/python/tests/test_compare.py @@ -0,0 +1,19 @@ +import unittest + +from preserves import * +from preserves.compare import * + +class BasicCompareTests(unittest.TestCase): + def test_eq_identity(self): + self.assertTrue(eq(1, 1)) + self.assertFalse(eq(1, 1.0)) + self.assertTrue(eq([], [])) + self.assertTrue(eq(Record(Symbol('hi'), []), Record(Symbol('hi'), []))) + + def test_cmp_identity(self): + self.assertEqual(cmp(1, 1), 0) + self.assertEqual(cmp(1, 1.0), 1) + self.assertEqual(cmp(1.0, 1), -1) + self.assertEqual(cmp([], []), 0) + self.assertEqual(cmp([], {}), -1) + self.assertEqual(cmp(Record(Symbol('hi'), []), Record(Symbol('hi'), [])), 0)