144 lines
3.5 KiB
Python
144 lines
3.5 KiB
Python
import numbers
|
|
from enum import Enum
|
|
from functools import cmp_to_key
|
|
|
|
from .values import preserve, Float, Embedded, Record, Symbol
|
|
from .compat import basestring_
|
|
|
|
class TypeNumber(Enum):
|
|
BOOL = 0
|
|
FLOAT = 1
|
|
DOUBLE = 2
|
|
SIGNED_INTEGER = 3
|
|
STRING = 4
|
|
BYTE_STRING = 5
|
|
SYMBOL = 6
|
|
|
|
RECORD = 7
|
|
SEQUENCE = 8
|
|
SET = 9
|
|
DICTIONARY = 10
|
|
|
|
EMBEDDED = 10
|
|
|
|
def type_number(v):
|
|
if hasattr(v, '__preserve__'):
|
|
raise ValueError('type_number expects Preserves value; use preserve()')
|
|
|
|
if isinstance(v, bool): return TypeNumber.BOOL
|
|
if isinstance(v, Float): return TypeNumber.FLOAT
|
|
if isinstance(v, float): return TypeNumber.DOUBLE
|
|
if isinstance(v, numbers.Number): return TypeNumber.SIGNED_INTEGER
|
|
if isinstance(v, basestring_): return TypeNumber.STRING
|
|
if isinstance(v, bytes): return TypeNumber.BYTE_STRING
|
|
if isinstance(v, Symbol): return TypeNumber.SYMBOL
|
|
|
|
if isinstance(v, Record): return TypeNumber.RECORD
|
|
if isinstance(v, list) or isinstance(v, tuple): return TypeNumber.SEQUENCE
|
|
if isinstance(v, set) or isinstance(v, frozenset): return TypeNumber.SET
|
|
if isinstance(v, dict): return TypeNumber.DICTIONARY
|
|
|
|
if isinstance(v, Embedded): return TypeNumber.EMBEDDED
|
|
|
|
try:
|
|
i = iter(v)
|
|
except TypeError:
|
|
i = None
|
|
if i is None:
|
|
raise ValueError('Invalid Preserves value in type_number')
|
|
else:
|
|
return TypeNumber.SEQUENCE
|
|
|
|
def cmp(a, b):
|
|
return _cmp(preserve(a), preserve(b))
|
|
|
|
def lt(a, b):
|
|
return cmp(a, b) < 0
|
|
|
|
def le(a, b):
|
|
return cmp(a, b) <= 0
|
|
|
|
def eq(a, b):
|
|
return _eq(preserve(a), preserve(b))
|
|
|
|
key = cmp_to_key(cmp)
|
|
_key = key
|
|
|
|
_sorted = sorted
|
|
def sorted(iterable, *, key=lambda x: x, reverse=False):
|
|
return _sorted(iterable, key=lambda x: _key(key(x)), reverse=reverse)
|
|
|
|
def sorted_items(d):
|
|
return sorted(d.items(), key=_item_key)
|
|
|
|
def _eq_sequences(aa, bb):
|
|
aa = list(aa)
|
|
bb = list(bb)
|
|
n = len(aa)
|
|
if len(bb) != n: return False
|
|
for i in range(n):
|
|
if not _eq(aa[i], bb[i]): return False
|
|
return True
|
|
|
|
def _item_key(item):
|
|
return item[0]
|
|
|
|
def _eq(a, b):
|
|
ta = type_number(a)
|
|
tb = type_number(b)
|
|
if ta != tb: return False
|
|
|
|
if ta == TypeNumber.EMBEDDED:
|
|
return ta.embeddedValue == tb.embeddedValue
|
|
|
|
if ta == TypeNumber.RECORD:
|
|
return _eq(a.key, b.key) and _eq_sequences(a.fields, b.fields)
|
|
|
|
if ta == TypeNumber.SEQUENCE:
|
|
return _eq_sequences(a, b)
|
|
|
|
if ta == TypeNumber.SET:
|
|
return _eq_sequences(sorted(a), sorted(b))
|
|
|
|
if ta == TypeNumber.DICTIONARY:
|
|
return _eq_sequences(sorted_items(a), sorted_items(b))
|
|
|
|
return a == b
|
|
|
|
def _simplecmp(a, b):
|
|
return (a > b) - (a < b)
|
|
|
|
def _cmp_sequences(aa, bb):
|
|
aa = list(aa)
|
|
bb = list(bb)
|
|
n = min(len(aa), len(bb))
|
|
for i in range(n):
|
|
v = _cmp(aa[i], bb[i])
|
|
if v != 0: return v
|
|
return len(aa) - len(bb)
|
|
|
|
def _cmp(a, b):
|
|
ta = type_number(a)
|
|
tb = type_number(b)
|
|
if ta.value < tb.value: return -1
|
|
if tb.value < ta.value: return 1
|
|
|
|
if ta == TypeNumber.EMBEDDED:
|
|
return _simplecmp(ta.embeddedValue, tb.embeddedValue)
|
|
|
|
if ta == TypeNumber.RECORD:
|
|
v = _cmp(a.key, b.key)
|
|
if v != 0: return v
|
|
return _cmp_sequences(a.fields, b.fields)
|
|
|
|
if ta == TypeNumber.SEQUENCE:
|
|
return _cmp_sequences(a, b)
|
|
|
|
if ta == TypeNumber.SET:
|
|
return _cmp_sequences(sorted(a), sorted(b))
|
|
|
|
if ta == TypeNumber.DICTIONARY:
|
|
return _cmp_sequences(sorted_items(a), sorted_items(b))
|
|
|
|
return _simplecmp(a, b)
|