Preserves-order for Python
This commit is contained in:
parent
00eb9e97b6
commit
c9c973ce9c
|
@ -1,6 +1,8 @@
|
||||||
from .values import Float, Symbol, Record, ImmutableDict, Embedded, preserve
|
from .values import Float, Symbol, Record, ImmutableDict, Embedded, preserve
|
||||||
from .values import Annotated, is_annotated, strip_annotations, annotate
|
from .values import Annotated, is_annotated, strip_annotations, annotate
|
||||||
|
|
||||||
|
from .compare import cmp
|
||||||
|
|
||||||
from .error import DecodeError, EncodeError, ShortPacket
|
from .error import DecodeError, EncodeError, ShortPacket
|
||||||
|
|
||||||
from .binary import Decoder, Encoder, decode, decode_with_annotations, encode, canonicalize
|
from .binary import Decoder, Encoder, decode, decode_with_annotations, encode, canonicalize
|
||||||
|
@ -8,7 +10,7 @@ from .text import Parser, Formatter, parse, parse_with_annotations, stringify
|
||||||
|
|
||||||
from .merge import merge
|
from .merge import merge
|
||||||
|
|
||||||
from . import fold
|
from . import fold, compare
|
||||||
|
|
||||||
loads = parse
|
loads = parse
|
||||||
dumps = stringify
|
dumps = stringify
|
||||||
|
|
|
@ -0,0 +1,143 @@
|
||||||
|
import numbers
|
||||||
|
from enum import Enum
|
||||||
|
from functools import cmp_to_key
|
||||||
|
|
||||||
|
from .values import preserve, Float, Embedded, Record, Symbol
|
||||||
|
from .compat import basestring_
|
||||||
|
|
||||||
|
class TypeNumber(Enum):
|
||||||
|
BOOL = 0
|
||||||
|
FLOAT = 1
|
||||||
|
DOUBLE = 2
|
||||||
|
SIGNED_INTEGER = 3
|
||||||
|
STRING = 4
|
||||||
|
BYTE_STRING = 5
|
||||||
|
SYMBOL = 6
|
||||||
|
|
||||||
|
RECORD = 7
|
||||||
|
SEQUENCE = 8
|
||||||
|
SET = 9
|
||||||
|
DICTIONARY = 10
|
||||||
|
|
||||||
|
EMBEDDED = 10
|
||||||
|
|
||||||
|
def type_number(v):
|
||||||
|
if hasattr(v, '__preserve__'):
|
||||||
|
raise ValueError('type_number expects Preserves value; use preserve()')
|
||||||
|
|
||||||
|
if isinstance(v, bool): return TypeNumber.BOOL
|
||||||
|
if isinstance(v, Float): return TypeNumber.FLOAT
|
||||||
|
if isinstance(v, float): return TypeNumber.DOUBLE
|
||||||
|
if isinstance(v, numbers.Number): return TypeNumber.SIGNED_INTEGER
|
||||||
|
if isinstance(v, basestring_): return TypeNumber.STRING
|
||||||
|
if isinstance(v, bytes): return TypeNumber.BYTE_STRING
|
||||||
|
if isinstance(v, Symbol): return TypeNumber.SYMBOL
|
||||||
|
|
||||||
|
if isinstance(v, Record): return TypeNumber.RECORD
|
||||||
|
if isinstance(v, list) or isinstance(v, tuple): return TypeNumber.SEQUENCE
|
||||||
|
if isinstance(v, set) or isinstance(v, frozenset): return TypeNumber.SET
|
||||||
|
if isinstance(v, dict): return TypeNumber.DICTIONARY
|
||||||
|
|
||||||
|
if isinstance(v, Embedded): return TypeNumber.EMBEDDED
|
||||||
|
|
||||||
|
try:
|
||||||
|
i = iter(v)
|
||||||
|
except TypeError:
|
||||||
|
i = None
|
||||||
|
if i is None:
|
||||||
|
raise ValueError('Invalid Preserves value in type_number')
|
||||||
|
else:
|
||||||
|
return TypeNumber.SEQUENCE
|
||||||
|
|
||||||
|
def cmp(a, b):
|
||||||
|
return _cmp(preserve(a), preserve(b))
|
||||||
|
|
||||||
|
def lt(a, b):
|
||||||
|
return cmp(a, b) < 0
|
||||||
|
|
||||||
|
def le(a, b):
|
||||||
|
return cmp(a, b) <= 0
|
||||||
|
|
||||||
|
def eq(a, b):
|
||||||
|
return _eq(preserve(a), preserve(b))
|
||||||
|
|
||||||
|
key = cmp_to_key(cmp)
|
||||||
|
_key = key
|
||||||
|
|
||||||
|
_sorted = sorted
|
||||||
|
def sorted(vs, /, *, key=lambda x: x, reverse=False):
|
||||||
|
return _sorted(vs, key=lambda x: _key(key(x)), reverse=reverse)
|
||||||
|
|
||||||
|
def sorted_items(d):
|
||||||
|
return sorted(d.items(), key=_item_key)
|
||||||
|
|
||||||
|
def _eq_sequences(aa, bb):
|
||||||
|
aa = list(aa)
|
||||||
|
bb = list(bb)
|
||||||
|
n = len(aa)
|
||||||
|
if len(bb) != n: return False
|
||||||
|
for i in range(n):
|
||||||
|
if not _eq(aa[i], bb[i]): return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _item_key(item):
|
||||||
|
return item[0]
|
||||||
|
|
||||||
|
def _eq(a, b):
|
||||||
|
ta = type_number(a)
|
||||||
|
tb = type_number(b)
|
||||||
|
if ta != tb: return False
|
||||||
|
|
||||||
|
if ta == TypeNumber.EMBEDDED:
|
||||||
|
return ta.embeddedValue == tb.embeddedValue
|
||||||
|
|
||||||
|
if ta == TypeNumber.RECORD:
|
||||||
|
return _eq(a.key, b.key) and _eq_sequences(a.fields, b.fields)
|
||||||
|
|
||||||
|
if ta == TypeNumber.SEQUENCE:
|
||||||
|
return _eq_sequences(a, b)
|
||||||
|
|
||||||
|
if ta == TypeNumber.SET:
|
||||||
|
return _eq_sequences(sorted(a), sorted(b))
|
||||||
|
|
||||||
|
if ta == TypeNumber.DICTIONARY:
|
||||||
|
return _eq_sequences(sorted_items(a), sorted_items(b))
|
||||||
|
|
||||||
|
return a == b
|
||||||
|
|
||||||
|
def _simplecmp(a, b):
|
||||||
|
return (a > b) - (a < b)
|
||||||
|
|
||||||
|
def _cmp_sequences(aa, bb):
|
||||||
|
aa = list(aa)
|
||||||
|
bb = list(bb)
|
||||||
|
n = min(len(aa), len(bb))
|
||||||
|
for i in range(n):
|
||||||
|
v = _cmp(aa[i], bb[i])
|
||||||
|
if v != 0: return v
|
||||||
|
return len(aa) - len(bb)
|
||||||
|
|
||||||
|
def _cmp(a, b):
|
||||||
|
ta = type_number(a)
|
||||||
|
tb = type_number(b)
|
||||||
|
if ta.value < tb.value: return -1
|
||||||
|
if tb.value < ta.value: return 1
|
||||||
|
|
||||||
|
if ta == TypeNumber.EMBEDDED:
|
||||||
|
return _simplecmp(ta.embeddedValue, tb.embeddedValue)
|
||||||
|
|
||||||
|
if ta == TypeNumber.RECORD:
|
||||||
|
v = _cmp(a.key, b.key)
|
||||||
|
if v != 0: return v
|
||||||
|
return _cmp_sequences(a.fields, b.fields)
|
||||||
|
|
||||||
|
if ta == TypeNumber.SEQUENCE:
|
||||||
|
return _cmp_sequences(a, b)
|
||||||
|
|
||||||
|
if ta == TypeNumber.SET:
|
||||||
|
return _cmp_sequences(sorted(a), sorted(b))
|
||||||
|
|
||||||
|
if ta == TypeNumber.DICTIONARY:
|
||||||
|
return _cmp_sequences(sorted_items(a), sorted_items(b))
|
||||||
|
|
||||||
|
return _simplecmp(a, b)
|
|
@ -48,6 +48,18 @@ class Symbol(object):
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
return self.name < other.name
|
||||||
|
|
||||||
|
def __le__(self, other):
|
||||||
|
return self.name <= other.name
|
||||||
|
|
||||||
|
def __gt__(self, other):
|
||||||
|
return self.name > other.name
|
||||||
|
|
||||||
|
def __ge__(self, other):
|
||||||
|
return self.name >= other.name
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.name)
|
return hash(self.name)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from preserves import *
|
||||||
|
from preserves.compare import *
|
||||||
|
|
||||||
|
class BasicCompareTests(unittest.TestCase):
|
||||||
|
def test_eq_identity(self):
|
||||||
|
self.assertTrue(eq(1, 1))
|
||||||
|
self.assertFalse(eq(1, 1.0))
|
||||||
|
self.assertTrue(eq([], []))
|
||||||
|
self.assertTrue(eq(Record(Symbol('hi'), []), Record(Symbol('hi'), [])))
|
||||||
|
|
||||||
|
def test_cmp_identity(self):
|
||||||
|
self.assertEqual(cmp(1, 1), 0)
|
||||||
|
self.assertEqual(cmp(1, 1.0), 1)
|
||||||
|
self.assertEqual(cmp(1.0, 1), -1)
|
||||||
|
self.assertEqual(cmp([], []), 0)
|
||||||
|
self.assertEqual(cmp([], {}), -1)
|
||||||
|
self.assertEqual(cmp(Record(Symbol('hi'), []), Record(Symbol('hi'), [])), 0)
|
Loading…
Reference in New Issue