366 lines
11 KiB
Python
366 lines
11 KiB
Python
import re
|
|
import sys
|
|
import struct
|
|
import math
|
|
|
|
from .error import DecodeError
|
|
|
|
def preserve(v):
    """Resolve `v` to its underlying value via the `__preserve__` protocol.

    As long as the current object exposes a `__preserve__` attribute, that
    attribute is called and its result becomes the current object. The first
    object without the attribute is returned.
    """
    current = v
    while True:
        if not hasattr(current, '__preserve__'):
            return current
        current = current.__preserve__()
|
|
|
|
def float_to_int(v):
    """Return the IEEE 754 double-precision bit pattern of `v`.

    The result is the 64-bit pattern interpreted as an unsigned big-endian
    integer, so values with the sign bit set are >= 2**63.
    """
    return struct.unpack('>Q', struct.pack('>d', v))[0]


def cmp_floats(a, b):
    """Three-way comparison of two floats by bit pattern.

    Returns a negative integer when `a` sorts before `b`, zero when both have
    the same bit pattern, and a positive integer otherwise. Because the
    comparison works on bit patterns, -0.0 sorts before +0.0 and NaNs are
    ordered by sign and payload rather than being unordered.
    """
    a = float_to_int(a)
    b = float_to_int(b)
    # For sign-magnitude doubles, flipping the 63 magnitude bits of a negative
    # value reverses the order among negatives. The pattern must then be read
    # as a *signed* 64-bit integer (hence the 2**64 subtraction); otherwise
    # every negative float would compare greater than every positive one.
    if a & 0x8000000000000000: a = (a ^ 0x7fffffffffffffff) - 0x10000000000000000
    if b & 0x8000000000000000: b = (b ^ 0x7fffffffffffffff) - 0x10000000000000000
    return a - b
|
|
|
|
class Float(object):
    """Wrapper around a Python float that is serialized as 4 bytes.

    Equality and ordering are defined on the bit pattern of `value` (via
    `cmp_floats`), so two NaNs with the same bits are equal and -0.0 differs
    from +0.0.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        """True when `other` (after removing an annotation wrapper) is an
        instance of exactly this class with a bit-identical value."""
        other = _unwrap(other)
        if other.__class__ is self.__class__:
            return cmp_floats(self.value, other.value) == 0
        return False

    def __lt__(self, other):
        """True when `other` is an instance of exactly this class and this
        value sorts before it; False for any other kind of object."""
        other = _unwrap(other)
        if other.__class__ is self.__class__:
            return cmp_floats(self.value, other.value) < 0
        return False

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Hash the bit pattern rather than the float itself so that hashing
        # agrees with __eq__: hash(float) of a NaN may depend on object
        # identity, which would give equal Float(nan) objects unequal hashes.
        return hash(struct.unpack('>Q', struct.pack('>d', self.value))[0])

    def __repr__(self):
        return 'Float(' + repr(self.value) + ')'

    def _to_bytes(self):
        """Return the 4-byte big-endian single-precision encoding of `value`.

        NaN and infinity are narrowed by hand: the sign bit and the top 23
        bits of the double's mantissa are copied into a single-precision
        pattern with an all-ones exponent. Other values go through
        `struct.pack('>f', ...)`.
        """
        if math.isnan(self.value) or math.isinf(self.value):
            dbs = struct.pack('>d', self.value)
            vd = struct.unpack('>Q', dbs)[0]
            sign = vd >> 63
            # 52-bit double mantissa -> 23-bit single mantissa: keep the top bits.
            payload = (vd >> 29) & 0x007fffff
            vf = (sign << 31) | 0x7f800000 | payload
            return struct.pack('>I', vf)
        else:
            return struct.pack('>f', self.value)

    def __preserve_write_binary__(self, encoder):
        """Append byte 0x82 followed by the 4-byte encoding to `encoder.buffer`."""
        encoder.buffer.append(0x82)
        encoder.buffer.extend(self._to_bytes())

    def __preserve_write_text__(self, formatter):
        """Append the text form of the value to `formatter.chunks`.

        NaN and infinity are written as `#xf"<hex of the 4 bytes>"`; any other
        value is written as its `repr` followed by `f`.
        """
        if math.isnan(self.value) or math.isinf(self.value):
            formatter.chunks.append('#xf"' + self._to_bytes().hex() + '"')
        else:
            formatter.chunks.append(repr(self.value) + 'f')

    @staticmethod
    def from_bytes(bs):
        """Build a Float from a 4-byte big-endian single-precision encoding.

        This is the inverse of `_to_bytes`; `bs` must be exactly 4 bytes long
        (a `struct.unpack` requirement).
        """
        vf = struct.unpack('>I', bs)[0]
        if (vf & 0x7f800000) == 0x7f800000:
            # NaN or inf. Preserve quiet/signalling bit by manually expanding to double-precision.
            sign = vf >> 31
            payload = vf & 0x007fffff
            dbs = struct.pack('>Q', (sign << 63) | 0x7ff0000000000000 | (payload << 29))
            return Float(struct.unpack('>d', dbs)[0])
        else:
            return Float(struct.unpack('>f', bs)[0])
|
|
|
|
# FIXME: This regular expression is conservatively correct, but Anglo-chauvinistic.
|
|
# Matches names made up entirely of ASCII letters, digits and the punctuation
# listed in the character class. Symbol.__preserve_write_text__ writes matching
# names verbatim and wraps all other names in `|...|`.
RAW_SYMBOL_RE = re.compile(r'^[-a-zA-Z0-9~!$%^&*?_=+/.]+$')
|
|
|
|
class Symbol(object):
    """A named symbol value.

    Symbols compare equal by `name`, order by `name`, and hash like their
    `name`. Constructing a Symbol from another Symbol copies the name.
    """

    def __init__(self, name):
        if isinstance(name, Symbol):
            name = name.name
        self.name = name

    def __eq__(self, other):
        # An annotation wrapper around the other operand is removed first.
        candidate = _unwrap(other)
        if not isinstance(candidate, Symbol):
            return False
        return self.name == candidate.name

    def __ne__(self, other):
        return not self.__eq__(other)

    # The ordering operators compare names directly; `other` must therefore
    # expose a `name` attribute.

    def __lt__(self, other):
        mine, theirs = self.name, other.name
        return mine < theirs

    def __le__(self, other):
        mine, theirs = self.name, other.name
        return mine <= theirs

    def __gt__(self, other):
        mine, theirs = self.name, other.name
        return mine > theirs

    def __ge__(self, other):
        mine, theirs = self.name, other.name
        return mine >= theirs

    def __hash__(self):
        return hash(self.name)

    def __repr__(self):
        return '#' + self.name

    def __preserve_write_binary__(self, encoder):
        """Write byte 0xb3, the varint length of the UTF-8 encoded name, then
        the encoded bytes to `encoder`."""
        encoded = self.name.encode('utf-8')
        encoder.buffer.append(0xb3)
        encoder.varint(len(encoded))
        encoder.buffer.extend(encoded)

    def __preserve_write_text__(self, formatter):
        """Write the name to `formatter`.

        Names matching RAW_SYMBOL_RE are written verbatim. Any other name is
        written between `|` delimiters, with `|` characters inside the name
        emitted as a backslash followed by `|` and every other character
        passed to `formatter.write_stringlike_char`.
        """
        if RAW_SYMBOL_RE.match(self.name):
            formatter.chunks.append(self.name)
            return
        formatter.chunks.append('|')
        for ch in self.name:
            if ch == '|':
                formatter.chunks.append('\\|')
            else:
                formatter.write_stringlike_char(ch)
        formatter.chunks.append('|')
|
|
|
|
class Record(object):
    """A labelled sequence of fields.

    `key` is the record's label; `fields` is stored as a tuple. Records are
    equal when both label and fields are equal, and hash accordingly.
    """

    def __init__(self, key, fields):
        self.key = key
        self.fields = tuple(fields)
        self.__hash = None  # computed on first use by __hash__

    def __eq__(self, other):
        # An annotation wrapper around the other operand is removed first.
        other = _unwrap(other)
        return isinstance(other, Record) and (self.key, self.fields) == (other.key, other.fields)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        if self.__hash is None:
            self.__hash = hash((self.key, self.fields))
        return self.__hash

    def __repr__(self):
        return str(self.key) + '(' + ', '.join((repr(f) for f in self.fields)) + ')'

    def __preserve_write_binary__(self, encoder):
        """Write byte 0xb4, the label, each field, then byte 0x84 to `encoder`."""
        encoder.buffer.append(0xb4)
        encoder.append(self.key)
        for f in self.fields:
            encoder.append(f)
        encoder.buffer.append(0x84)

    def __preserve_write_text__(self, formatter):
        """Write `<label field field ...>` to `formatter`."""
        formatter.chunks.append('<')
        formatter.append(self.key)
        for f in self.fields:
            formatter.chunks.append(' ')
            formatter.append(f)
        formatter.chunks.append('>')

    def __getitem__(self, index):
        return self.fields[index]

    @staticmethod
    def makeConstructor(labelSymbolText, fieldNames):
        """Like `makeBasicConstructor`, with the label given as symbol text."""
        return Record.makeBasicConstructor(Symbol(labelSymbolText), fieldNames)

    @staticmethod
    def makeBasicConstructor(label, fieldNames):
        """Build a constructor function for records with a fixed label and arity.

        `fieldNames` is either a sequence of names or a single string of
        whitespace-separated names. The returned function takes exactly one
        positional argument per field name and returns a Record. It also
        carries these attributes:

        - `constructorInfo`: a RecordConstructorInfo for the label and arity;
        - `isClassOf(v)`: whether `v` is a Record with this label and arity;
        - `ensureClassOf(v)`: returns `v`, or raises TypeError if
          `isClassOf(v)` is false;
        - `_<fieldName>(v)`: returns the corresponding field of `v` after
          checking it with `ensureClassOf`.
        """
        # isinstance rather than an exact type comparison, so that str
        # subclasses are split into names instead of iterated per character.
        if isinstance(fieldNames, str):
            fieldNames = fieldNames.split()
        arity = len(fieldNames)

        def ctor(*fields):
            if len(fields) != arity:
                raise Exception("Record: cannot instantiate %r expecting %d fields with %d fields"%(
                    label,
                    arity,
                    len(fields)))
            return Record(label, fields)
        ctor.constructorInfo = RecordConstructorInfo(label, arity)

        def isClassOf(v):
            return isinstance(v, Record) and v.key == label and len(v.fields) == arity
        ctor.isClassOf = isClassOf

        def ensureClassOf(v):
            if not ctor.isClassOf(v):
                raise TypeError("Record: expected %r/%d, got %r" % (label, arity, v))
            return v
        ctor.ensureClassOf = ensureClassOf

        # A factory is needed so that each getter captures its own index;
        # a lambda created directly in the loop would see only the last one.
        def make_getter(index):
            return lambda v: ensureClassOf(v)[index]

        for fieldIndex, fieldName in enumerate(fieldNames):
            setattr(ctor, '_' + fieldName, make_getter(fieldIndex))
        return ctor
|
|
|
|
class RecordConstructorInfo(object):
    """Describes a record constructor by its label (`key`) and field count
    (`arity`). Instances are equal when both match, and hash accordingly."""

    def __init__(self, key, arity):
        self.key = key
        self.arity = arity
        # Lazily computed hash cache. It must be initialised here: __hash__
        # reads it before ever assigning it, so without this line hashing an
        # instance raised AttributeError.
        self.__hash = None

    def __eq__(self, other):
        # An annotation wrapper around the other operand is removed first.
        other = _unwrap(other)
        return isinstance(other, RecordConstructorInfo) and \
            (self.key, self.arity) == (other.key, other.arity)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        if self.__hash is None:
            self.__hash = hash((self.key, self.arity))
        return self.__hash

    def __repr__(self):
        return str(self.key) + '/' + str(self.arity)
|
|
|
|
class ImmutableDict(dict):
    """A hashable dict whose mutating methods raise TypeError('Immutable').

    Contents are supplied at construction time or through `from_kvs`. The
    hash is computed on first use and cached; it requires every key and value
    to be hashable.
    """

    def __init__(self, *args, **kwargs):
        # Refuse a second __init__ call, which would otherwise add entries.
        # The cache attribute is assigned as `self.__hash` inside this class,
        # so its real (name-mangled) name is `_ImmutableDict__hash`; a string
        # passed to hasattr is not mangled and has to be spelled out in full.
        if hasattr(self, '_ImmutableDict__hash'): raise TypeError('Immutable')
        super(ImmutableDict, self).__init__(*args, **kwargs)
        self.__hash = None

    def __delitem__(self, key): raise TypeError('Immutable')
    def __setitem__(self, key, val): raise TypeError('Immutable')
    def __ior__(self, other): raise TypeError('Immutable')
    def clear(self): raise TypeError('Immutable')
    def pop(self, k, d=None): raise TypeError('Immutable')
    def popitem(self): raise TypeError('Immutable')
    def setdefault(self, k, d=None): raise TypeError('Immutable')
    def update(self, e, **f): raise TypeError('Immutable')

    def __hash__(self):
        if self.__hash is None:
            # dict equality ignores insertion order, so the hash must too:
            # equal dictionaries are required to have equal hashes.
            self.__hash = hash(frozenset(self.items()))
        return self.__hash

    @staticmethod
    def from_kvs(kvs):
        """Build an ImmutableDict from an iterable of alternating keys and values.

        Raises DecodeError if the iterable ends after a key without a
        matching value.
        """
        i = iter(kvs)
        result = ImmutableDict()
        # Go through dict's own __setitem__, since ours always raises.
        result_proxy = super(ImmutableDict, result)
        for k in i:
            try:
                v = next(i)
            except StopIteration:
                raise DecodeError("Missing dictionary value")
            result_proxy.__setitem__(k, v)
        return result
|
|
|
|
def dict_kvs(d):
    """Generate the contents of mapping `d` as a flat sequence: each key is
    yielded, followed immediately by the value stored under it."""
    for key in iter(d):
        yield key
        yield d[key]
|
|
|
|
# Floating-point infinity, used as the default `depth` when stripping
# annotations: subtracting 1 from it never reaches 0, so stripping is unbounded.
inf = float('inf')
|
|
|
|
class Annotated(object):
    """A value (`item`) paired with a list of `annotations`.

    Equality and hashing look only at `item`; annotations are ignored.
    """

    def __init__(self, item):
        self.item = item
        self.annotations = []

    def __preserve_write_binary__(self, encoder):
        """Write each annotation, preceded by byte 0x85, then the item."""
        for annotation in self.annotations:
            encoder.buffer.append(0x85)
            encoder.append(annotation)
        encoder.append(self.item)

    def __preserve_write_text__(self, formatter):
        """Write each annotation as `@<annotation> `, then the item."""
        for annotation in self.annotations:
            formatter.chunks.append('@')
            formatter.append(annotation)
            formatter.chunks.append(' ')
        formatter.append(self.item)

    def strip(self, depth=inf):
        """Return this value with annotations removed down to `depth` levels."""
        return strip_annotations(self, depth)

    def peel(self):
        """Return this value with only the outermost level of annotations removed."""
        return strip_annotations(self, 1)

    def __eq__(self, other):
        # Compare the wrapped items; if `other` is annotated too, its wrapper
        # is removed first.
        return self.item == _unwrap(other)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.item)

    def __repr__(self):
        pieces = ['@' + repr(annotation) for annotation in self.annotations]
        pieces.append(repr(self.item))
        return ' '.join(pieces)
|
|
|
|
def is_annotated(v):
    """Tell whether `v` is an Annotated wrapper (True) or not (False)."""
    result = isinstance(v, Annotated)
    return result
|
|
|
|
def strip_annotations(v, depth=inf):
    """Remove annotation wrappers from `v`, descending at most `depth` levels.

    `v` is returned unchanged when `depth` is 0 or when it is not annotated.
    Otherwise the wrapped item is rebuilt with its children stripped to
    `depth - 1`: records stay records (the label is stripped with the same
    `depth`), lists and tuples become tuples, sets and frozensets become
    frozensets, and dicts become ImmutableDicts. Any other item is returned
    as is.

    Raises ValueError if the wrapped item is itself an annotated value.
    """
    if depth == 0 or not is_annotated(v):
        return v

    remaining = depth - 1

    def descend(child):
        return strip_annotations(child, remaining)

    inner = v.item
    if isinstance(inner, Record):
        return Record(strip_annotations(inner.key, depth),
                      tuple(descend(field) for field in inner.fields))
    if isinstance(inner, (list, tuple)):
        return tuple(descend(element) for element in inner)
    if isinstance(inner, (set, frozenset)):
        return frozenset(descend(element) for element in inner)
    if isinstance(inner, dict):
        return ImmutableDict.from_kvs(descend(entry) for entry in dict_kvs(inner))
    if is_annotated(inner):
        raise ValueError('Improper annotation structure')
    return inner
|
|
|
|
def annotate(v, *anns):
    """Attach the annotations `anns` to `v` and return the annotated value.

    A value that is already annotated is extended in place and returned
    itself; any other value is first wrapped in a new Annotated.
    """
    target = v if is_annotated(v) else Annotated(v)
    for annotation in anns:
        target.annotations.append(annotation)
    return target
|
|
|
|
def _unwrap(x):
    """Return the item wrapped by `x` when `x` is annotated, else `x` itself.

    Only one layer is removed.
    """
    return x.item if is_annotated(x) else x
|
|
|
|
class Embedded:
    """Wrapper marking `embeddedValue` as an embedded value.

    Two wrappers of exactly the same class are equal when their embedded
    values are equal; the hash is that of the embedded value.
    """

    def __init__(self, value):
        self.embeddedValue = value

    def __eq__(self, other):
        # An annotation wrapper around the other operand is removed first.
        candidate = _unwrap(other)
        if candidate.__class__ is not self.__class__:
            # No verdict for objects of a different class: the result is None,
            # which is falsy.
            return None
        return self.embeddedValue == candidate.embeddedValue

    def __hash__(self):
        return hash(self.embeddedValue)

    def __repr__(self):
        return '#!' + repr(self.embeddedValue)

    def __preserve_write_binary__(self, encoder):
        """Write byte 0x86, then the result of `encoder.encode_embedded` for
        the embedded value."""
        encoder.buffer.append(0x86)
        encoded = encoder.encode_embedded(self.embeddedValue)
        encoder.append(encoded)

    def __preserve_write_text__(self, formatter):
        """Write `#!`, then the result of `formatter.format_embedded` for the
        embedded value."""
        formatter.chunks.append('#!')
        formatted = formatter.format_embedded(self.embeddedValue)
        formatter.append(formatted)
|