preserves/implementations/python/preserves/compare.py

154 lines
3.8 KiB
Python

import numbers
from enum import Enum
from functools import cmp_to_key
from .values import preserve, Float, Embedded, Record, Symbol, cmp_floats, _unwrap
from .compat import basestring_
class TypeNumber(Enum):
BOOL = 0
FLOAT = 1
DOUBLE = 2
SIGNED_INTEGER = 3
STRING = 4
BYTE_STRING = 5
SYMBOL = 6
RECORD = 7
SEQUENCE = 8
SET = 9
DICTIONARY = 10
EMBEDDED = 11
def type_number(v):
if hasattr(v, '__preserve__'):
raise ValueError('type_number expects Preserves value; use preserve()')
if isinstance(v, bool): return TypeNumber.BOOL
if isinstance(v, Float): return TypeNumber.FLOAT
if isinstance(v, float): return TypeNumber.DOUBLE
if isinstance(v, numbers.Number): return TypeNumber.SIGNED_INTEGER
if isinstance(v, basestring_): return TypeNumber.STRING
if isinstance(v, bytes): return TypeNumber.BYTE_STRING
if isinstance(v, Symbol): return TypeNumber.SYMBOL
if isinstance(v, Record): return TypeNumber.RECORD
if isinstance(v, list) or isinstance(v, tuple): return TypeNumber.SEQUENCE
if isinstance(v, set) or isinstance(v, frozenset): return TypeNumber.SET
if isinstance(v, dict): return TypeNumber.DICTIONARY
if isinstance(v, Embedded): return TypeNumber.EMBEDDED
try:
i = iter(v)
except TypeError:
i = None
if i is None:
raise ValueError('Invalid Preserves value in type_number')
else:
return TypeNumber.SEQUENCE
def cmp(a, b):
return _cmp(preserve(a), preserve(b))
def lt(a, b):
return cmp(a, b) < 0
def le(a, b):
return cmp(a, b) <= 0
def eq(a, b):
return _eq(preserve(a), preserve(b))
key = cmp_to_key(cmp)
_key = key
_sorted = sorted
def sorted(iterable, *, key=lambda x: x, reverse=False):
return _sorted(iterable, key=lambda x: _key(key(x)), reverse=reverse)
def sorted_items(d):
return sorted(d.items(), key=_item_key)
def _eq_sequences(aa, bb):
aa = list(aa)
bb = list(bb)
n = len(aa)
if len(bb) != n: return False
for i in range(n):
if not _eq(aa[i], bb[i]): return False
return True
def _item_key(item):
return item[0]
def _eq(a, b):
a = _unwrap(a)
b = _unwrap(b)
ta = type_number(a)
tb = type_number(b)
if ta != tb: return False
if ta == TypeNumber.DOUBLE:
return cmp_floats(a, b) == 0
if ta == TypeNumber.EMBEDDED:
return _eq(a.embeddedValue, b.embeddedValue)
if ta == TypeNumber.RECORD:
return _eq(a.key, b.key) and _eq_sequences(a.fields, b.fields)
if ta == TypeNumber.SEQUENCE:
return _eq_sequences(a, b)
if ta == TypeNumber.SET:
return _eq_sequences(sorted(a), sorted(b))
if ta == TypeNumber.DICTIONARY:
return _eq_sequences(sorted_items(a), sorted_items(b))
return a == b
def _simplecmp(a, b):
return (a > b) - (a < b)
def _cmp_sequences(aa, bb):
aa = list(aa)
bb = list(bb)
n = min(len(aa), len(bb))
for i in range(n):
v = _cmp(aa[i], bb[i])
if v != 0: return v
return len(aa) - len(bb)
def _cmp(a, b):
a = _unwrap(a)
b = _unwrap(b)
ta = type_number(a)
tb = type_number(b)
if ta.value < tb.value: return -1
if tb.value < ta.value: return 1
if ta == TypeNumber.DOUBLE:
return cmp_floats(a, b)
if ta == TypeNumber.EMBEDDED:
return _cmp(a.embeddedValue, b.embeddedValue)
if ta == TypeNumber.RECORD:
v = _cmp(a.key, b.key)
if v != 0: return v
return _cmp_sequences(a.fields, b.fields)
if ta == TypeNumber.SEQUENCE:
return _cmp_sequences(a, b)
if ta == TypeNumber.SET:
return _cmp_sequences(sorted(a), sorted(b))
if ta == TypeNumber.DICTIONARY:
return _cmp_sequences(sorted_items(a), sorted_items(b))
return _simplecmp(a, b)