446 lines
14 KiB
Python
446 lines
14 KiB
Python
import sys
|
|
import numbers
|
|
import struct
|
|
|
|
# Python 2/3 compatibility shims.
try:
    basestring  # Python 2: common base class of str and unicode
except NameError:
    basestring = str  # Python 3: all text is str

# Indexing a byte string yields a one-character str on Python 2 (where chr()
# returns bytes) but already an int on Python 3; _ord normalises to an int.
if isinstance(chr(123), bytes):
    _ord = ord
else:
    def _ord(x):
        return x
|
|
|
|
class Float(object):
    """Wrapper marking *value* for encoding as a 4-byte single-precision float.

    Encoded with tag 0x82 followed by the big-endian IEEE-754 single
    (``struct`` format ``'>f'``).
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Only another Float can be equal. Return an explicit False for any
        # other operand rather than falling off the end (which yields None).
        return other.__class__ is self.__class__ and self.value == other.value

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.value)

    def __repr__(self):
        return 'Float(' + repr(self.value) + ')'

    def __preserve_on__(self, encoder):
        """Write tag 0x82 and the 4-byte big-endian single to *encoder*."""
        encoder.buffer.append(0x82)
        encoder.buffer.extend(struct.pack('>f', self.value))
|
|
|
|
class Symbol(object):
    """A symbolic name, encoded distinctly from plain strings (tag 0xb3)."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        if not isinstance(other, Symbol):
            return False
        return self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.name)

    def __repr__(self):
        return '#' + self.name

    def __preserve_on__(self, encoder):
        """Write tag 0xb3, the varint byte length, then the UTF-8 name."""
        utf8 = self.name.encode('utf-8')
        out = encoder.buffer
        out.append(0xb3)
        encoder.varint(len(utf8))
        out.extend(utf8)
|
|
|
|
class Record(object):
    """A labelled tuple: a *key* (label) plus an ordered tuple of fields.

    Records compare and hash by ``(key, fields)``. The hash is computed on
    first use and cached, since ``fields`` is stored as an immutable tuple.
    """

    def __init__(self, key, fields):
        self.key = key
        self.fields = tuple(fields)
        self.__hash = None  # filled in lazily by __hash__

    def __eq__(self, other):
        return isinstance(other, Record) and (self.key, self.fields) == (other.key, other.fields)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        if self.__hash is None:
            self.__hash = hash((self.key, self.fields))
        return self.__hash

    def __repr__(self):
        return str(self.key) + '(' + ', '.join(repr(f) for f in self.fields) + ')'

    def __preserve_on__(self, encoder):
        """Write tag 0xb4, the key, each field, then the 0x84 end marker."""
        encoder.buffer.append(0xb4)
        encoder.append(self.key)
        for f in self.fields:
            encoder.append(f)
        encoder.buffer.append(0x84)

    def __getitem__(self, index):
        return self.fields[index]

    @staticmethod
    def makeConstructor(labelSymbolText, fieldNames):
        """Like makeBasicConstructor, with the label given as Symbol text."""
        return Record.makeBasicConstructor(Symbol(labelSymbolText), fieldNames)

    @staticmethod
    def makeBasicConstructor(label, fieldNames):
        """Return a constructor for records with key *label* and the given fields.

        *fieldNames* is a sequence of names or a whitespace-separated string.
        The returned function builds a Record after checking the field count.
        It also carries:

        * ``constructorInfo`` -- a RecordConstructorInfo(label, arity);
        * ``isClassOf(v)`` -- True if *v* is a Record with this label and arity;
        * ``ensureClassOf(v)`` -- returns *v* or raises TypeError;
        * ``_<fieldName>(v)`` -- one accessor per field.
        """
        # isinstance (rather than an exact type() comparison) also accepts str
        # subclasses and, on Python 2, unicode field-name strings.
        if isinstance(fieldNames, basestring):
            fieldNames = fieldNames.split()
        arity = len(fieldNames)

        def ctor(*fields):
            if len(fields) != arity:
                raise Exception("Record: cannot instantiate %r expecting %d fields with %d fields"%(
                    label,
                    arity,
                    len(fields)))
            return Record(label, fields)

        ctor.constructorInfo = RecordConstructorInfo(label, arity)
        ctor.isClassOf = lambda v: \
            isinstance(v, Record) and v.key == label and len(v.fields) == arity

        def ensureClassOf(v):
            if not ctor.isClassOf(v):
                raise TypeError("Record: expected %r/%d, got %r" % (label, arity, v))
            return v

        ctor.ensureClassOf = ensureClassOf

        def getter(index):
            # Factory so each accessor captures its own index. A lambda defined
            # directly in the loop would late-bind and see only the last index.
            return lambda v: ensureClassOf(v)[index]

        for fieldIndex, fieldName in enumerate(fieldNames):
            setattr(ctor, '_' + fieldName, getter(fieldIndex))
        return ctor
|
|
|
|
class RecordConstructorInfo(object):
    """Describes a record shape: its key (label) and arity (field count)."""

    def __init__(self, key, arity):
        self.key = key
        self.arity = arity
        # Cache slot read by __hash__. It must exist before the first hash()
        # call; previously it was never initialised and hashing raised
        # AttributeError.
        self.__hash = None

    def __eq__(self, other):
        return isinstance(other, RecordConstructorInfo) and \
            (self.key, self.arity) == (other.key, other.arity)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        if self.__hash is None:
            self.__hash = hash((self.key, self.arity))
        return self.__hash

    def __repr__(self):
        return str(self.key) + '/' + str(self.arity)
|
|
|
|
# Read-only dict that can be hashed, so decoded dictionaries can appear inside sets and as dictionary keys.
|
|
class ImmutableDict(dict):
    """A dict that rejects mutation after construction and is therefore hashable.

    Build instances with the ordinary dict constructor or with from_kvs().
    Every mutating method raises TypeError('Immutable').
    """

    def __init__(self, *args, **kwargs):
        # Refuse to re-initialise an existing instance, because dict.__init__
        # would insert entries into it. The cache attribute below is
        # name-mangled to _ImmutableDict__hash, so check that exact name; a
        # plain '__hash' string never matches.
        if hasattr(self, '_ImmutableDict__hash'): raise TypeError('Immutable')
        super(ImmutableDict, self).__init__(*args, **kwargs)
        self.__hash = None

    def __delitem__(self, key): raise TypeError('Immutable')
    def __setitem__(self, key, val): raise TypeError('Immutable')
    def clear(self): raise TypeError('Immutable')
    def pop(self, k, d=None): raise TypeError('Immutable')
    def popitem(self): raise TypeError('Immutable')
    def setdefault(self, k, d=None): raise TypeError('Immutable')
    def update(self, *args, **kwargs): raise TypeError('Immutable')
    # dict.__ior__ (d |= other) would otherwise mutate in place without
    # going through update().
    def __ior__(self, other): raise TypeError('Immutable')

    def __hash__(self):
        if self.__hash is None:
            # Order-independent, so dicts that compare equal but were built
            # in different insertion orders also hash equal. Requires all
            # values to be hashable.
            self.__hash = hash(frozenset(self.items()))
        return self.__hash

    @staticmethod
    def from_kvs(kvs):
        """Build an ImmutableDict from a flat iterable k1, v1, k2, v2, ...

        Raises DecodeError("Missing dictionary value") if the iterable yields
        an odd number of items. Later duplicate keys overwrite earlier ones.
        """
        items = iter(kvs)
        result = ImmutableDict()
        for k in items:
            try:
                v = next(items)
            except StopIteration:
                raise DecodeError("Missing dictionary value")
            # Bypass the overridden __setitem__ while filling the new instance.
            dict.__setitem__(result, k, v)
        return result
|
|
|
|
def dict_kvs(d):
    """Lazily yield the entries of mapping *d* flattened as k1, v1, k2, v2, ...

    This is the inverse of ImmutableDict.from_kvs. Each key is yielded before
    its value is looked up.
    """
    for key in d:
        yield key
        yield d[key]
|
|
|
|
class DecodeError(ValueError):
    """Raised when input bytes cannot be decoded into a value."""


class EncodeError(ValueError):
    """Encoding-side counterpart of DecodeError."""


class ShortPacket(DecodeError):
    """Raised when the input ends before a complete value has been read."""


class Codec(object):
    """Common base class of the encoder and decoder."""


# Unlimited depth, used as the default for annotation stripping.
inf = float('inf')
|
|
|
|
class Annotated(object):
    """A value (``item``) together with a list of ``annotations``.

    Equality and hashing consider only ``item``; annotations are ignored.
    """

    def __init__(self, item):
        self.annotations = []
        self.item = item

    def __preserve_on__(self, encoder):
        """Write each annotation prefixed by tag 0x85, then the item itself."""
        for a in self.annotations:
            encoder.buffer.append(0x85)
            encoder.append(a)
        encoder.append(self.item)

    def strip(self, depth=inf):
        """Return the value with annotations removed down to *depth* levels."""
        return strip_annotations(self, depth)

    def peel(self):
        """Return the value with only the outermost layer of annotations removed."""
        return strip_annotations(self, 1)

    def __eq__(self, other):
        # Only another Annotated can be equal. Return an explicit False for
        # any other operand rather than falling off the end (which yields None).
        return other.__class__ is self.__class__ and self.item == other.item

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.item)

    def __repr__(self):
        return ' '.join(list('@' + repr(a) for a in self.annotations) + [repr(self.item)])
|
|
|
|
def is_annotated(v):
    """Report whether *v* is an Annotated wrapper (or a subclass instance)."""
    answer = isinstance(v, Annotated)
    return answer
|
|
|
|
def strip_annotations(v, depth=inf):
    """Remove Annotated wrappers from *v*, descending at most *depth* levels.

    A value that is not Annotated (or any value when depth is 0) is returned
    unchanged. Otherwise the wrapper is dropped and the wrapped container is
    rebuilt in immutable form: lists/tuples become tuples, sets/frozensets
    become frozensets, dicts become ImmutableDict, and Records keep their
    type. An Annotated directly wrapping another Annotated raises ValueError.
    """
    if depth == 0 or not is_annotated(v):
        return v

    child_depth = depth - 1

    def walk(x):
        return strip_annotations(x, child_depth)

    inner = v.item
    if isinstance(inner, Record):
        # The label is stripped at the current depth; fields one level deeper.
        label = strip_annotations(inner.key, depth)
        return Record(label, tuple(walk(field) for field in inner.fields))
    if isinstance(inner, (list, tuple)):
        return tuple(walk(x) for x in inner)
    if isinstance(inner, (set, frozenset)):
        return frozenset(walk(x) for x in inner)
    if isinstance(inner, dict):
        return ImmutableDict.from_kvs(walk(x) for x in dict_kvs(inner))
    if is_annotated(inner):
        raise ValueError('Improper annotation structure')
    return inner
|
|
|
|
def annotate(v, *anns):
    """Attach annotations *anns* to *v* and return the Annotated result.

    If *v* is already Annotated, its annotation list is extended in place and
    the same object is returned; otherwise *v* is wrapped in a new Annotated.
    """
    target = v if is_annotated(v) else Annotated(v)
    target.annotations.extend(anns)
    return target
|
|
|
|
class Decoder(Codec):
    """Incremental decoder over a buffer of bytes.

    ``packet`` holds the bytes received so far and ``index`` the read
    position; call extend() to add more input and try_next() to decode a
    value only if it is complete. When ``include_annotations`` is true, every
    decoded value is wrapped in Annotated so that annotations (tag 0x85) can
    be attached to it; otherwise annotations are read and discarded.
    ``decode_embedded`` converts the value following tag 0x86 into an
    application object.
    """

    def __init__(self, packet=b'', include_annotations=False, decode_embedded=None):
        super(Decoder, self).__init__()
        self.packet = packet
        self.index = 0
        self.include_annotations = include_annotations
        self.decode_embedded = decode_embedded

    def extend(self, data):
        """Drop already-consumed bytes and append *data* to the unread remainder."""
        self.packet = self.packet[self.index:] + data
        self.index = 0

    def nextbyte(self):
        """Consume one byte and return it as an int; ShortPacket at end of input."""
        if self.index >= len(self.packet):
            raise ShortPacket('Short packet')
        self.index = self.index + 1
        return _ord(self.packet[self.index - 1])

    def nextbytes(self, n):
        """Consume and return the next *n* bytes; ShortPacket if fewer remain."""
        start = self.index
        end = start + n
        if end > len(self.packet):
            raise ShortPacket('Short packet')
        self.index = end
        return self.packet[start : end]

    def varint(self):
        """Read an unsigned varint.

        Each byte carries 7 bits, least-significant group first. The high bit
        is set on every byte except the last.
        """
        # Iterative rather than recursive: a long run of continuation bytes in
        # (possibly untrusted) input must not exhaust the Python call stack.
        result = 0
        shift = 0
        while True:
            b = self.nextbyte()
            result |= (b & 0x7f) << shift
            if b < 128:
                return result
            shift += 7

    def peekend(self):
        """If the next byte is the 0x84 end marker, consume it and return True.

        Otherwise leave the position unchanged and return False. Raises
        ShortPacket at end of input.
        """
        matched = (self.nextbyte() == 0x84)
        if not matched:
            self.index = self.index - 1
        return matched

    def nextvalues(self):
        """Decode values up to and including the 0x84 end marker; return a list."""
        result = []
        while not self.peekend():
            result.append(self.next())
        return result

    def nextint(self, n):
        """Read an *n*-byte big-endian two's-complement integer (0 when n == 0)."""
        if n == 0: return 0
        acc = self.nextbyte()
        if acc & 0x80: acc = acc - 256  # sign-extend from the leading byte
        for _i in range(n - 1):
            acc = (acc << 8) | self.nextbyte()
        return acc

    def wrap(self, v):
        """Wrap *v* in Annotated when annotations are being kept."""
        return Annotated(v) if self.include_annotations else v

    def unshift_annotation(self, a, v):
        """Prepend annotation *a* to *v* when annotations are kept; return *v*."""
        if self.include_annotations:
            v.annotations.insert(0, a)
        return v

    def next(self):
        """Decode and return the next value.

        Raises DecodeError on malformed input and ShortPacket if the input
        ends part-way through the value.
        """
        tag = self.nextbyte()
        if tag == 0x80: return self.wrap(False)
        if tag == 0x81: return self.wrap(True)
        if tag == 0x82: return self.wrap(Float(struct.unpack('>f', self.nextbytes(4))[0]))
        if tag == 0x83: return self.wrap(struct.unpack('>d', self.nextbytes(8))[0])
        if tag == 0x84: raise DecodeError('Unexpected end-of-stream marker')
        if tag == 0x85:
            # Annotation, then the annotated value.
            a = self.next()
            v = self.next()
            return self.unshift_annotation(a, v)
        if tag == 0x86:
            if self.decode_embedded is None:
                raise DecodeError('No decode_embedded function supplied')
            return self.wrap(self.decode_embedded(self.next()))
        # One-byte small integers: 0x90..0x9c are 0..12, 0x9d..0x9f are -3..-1.
        if 0x90 <= tag <= 0x9f: return self.wrap(tag - (0xa0 if tag > 0x9c else 0x90))
        # 1..16-byte integers with the length in the low nibble.
        if 0xa0 <= tag <= 0xaf: return self.wrap(self.nextint(tag - 0xa0 + 1))
        if tag == 0xb0: return self.wrap(self.nextint(self.varint()))
        if tag == 0xb1: return self.wrap(self.nextbytes(self.varint()).decode('utf-8'))
        if tag == 0xb2: return self.wrap(self.nextbytes(self.varint()))
        if tag == 0xb3: return self.wrap(Symbol(self.nextbytes(self.varint()).decode('utf-8')))
        if tag == 0xb4:
            vs = self.nextvalues()
            if not vs: raise DecodeError('Too few elements in encoded record')
            return self.wrap(Record(vs[0], vs[1:]))
        if tag == 0xb5: return self.wrap(tuple(self.nextvalues()))
        if tag == 0xb6: return self.wrap(frozenset(self.nextvalues()))
        if tag == 0xb7: return self.wrap(ImmutableDict.from_kvs(self.nextvalues()))
        raise DecodeError('Invalid tag: ' + hex(tag))

    def try_next(self):
        """Like next(), but return None if the input is incomplete.

        On ShortPacket the read position is rewound to where this value
        started, so decoding can be retried after extend().
        """
        start = self.index
        try:
            return self.next()
        except ShortPacket:
            self.index = start
            return None
|
|
|
|
def decode(bs, **kwargs):
    """Decode and return the first value encoded in *bs*.

    Keyword arguments are forwarded to Decoder. Bytes after the first
    value are not examined.
    """
    decoder = Decoder(packet=bs, **kwargs)
    value = decoder.next()
    return value
|
|
|
|
def decode_with_annotations(bs, **kwargs):
    """Like decode(), but keep annotations (every value comes back Annotated).

    Other keyword arguments are forwarded to Decoder.
    """
    decoder = Decoder(packet=bs, include_annotations=True, **kwargs)
    value = decoder.next()
    return value
|
|
|
|
class Encoder(Codec):
    """Accumulates the binary encoding of values in ``buffer``.

    ``encode_embedded`` is applied to values that have no direct encoding and
    are not iterable; its result is written after tag 0x86. The default,
    ``id``, therefore encodes such objects as their integer identity.
    """

    def __init__(self, encode_embedded=id):
        super(Encoder, self).__init__()
        self.buffer = bytearray()
        self.encode_embedded = encode_embedded

    def contents(self):
        """Return everything encoded so far as an immutable bytes object."""
        return bytes(self.buffer)

    def varint(self, v):
        """Append non-negative *v* as a varint.

        Each byte carries 7 bits, least-significant group first. The high bit
        is set on all but the last byte.
        """
        # Loop rather than recursing once per output byte.
        while v >= 128:
            self.buffer.append((v % 128) + 128)
            v = v // 128
        self.buffer.append(v)

    def encodeint(self, v):
        """Append integer *v* as big-endian two's complement.

        Uses the fewest bytes that hold *v* plus a sign bit. Lengths up to 16
        bytes use tags 0xa0..0xaf; longer values use tag 0xb0 followed by a
        varint length.
        """
        bitcount = (~v if v < 0 else v).bit_length() + 1
        bytecount = (bitcount + 7) // 8
        if bytecount <= 16:
            self.buffer.append(0xa0 + bytecount - 1)
        else:
            self.buffer.append(0xb0)
            self.varint(bytecount)
        # Emit the bytes most-significant first. Looping, rather than
        # recursing once per byte, keeps very large integers from exceeding
        # the recursion limit. An arithmetic right shift plus & 255 yields the
        # two's-complement byte for negative values too.
        for shift in range(8 * (bytecount - 1), -1, -8):
            self.buffer.append((v >> shift) & 255)

    def encodevalues(self, tag, items):
        """Append tag 0xb0+*tag*, the encoding of each item, then end marker 0x84."""
        self.buffer.append(0xb0 + tag)
        for i in items: self.append(i)
        self.buffer.append(0x84)

    def encodebytes(self, tag, bs):
        """Append tag 0xb0+*tag*, the varint length of *bs*, then *bs* itself."""
        self.buffer.append(0xb0 + tag)
        self.varint(len(bs))
        self.buffer.extend(bs)

    def append(self, v):
        """Append the encoding of *v* to the buffer.

        Dispatch order matters:

        * objects with ``__preserve_on__`` encode themselves;
        * ``False``/``True`` are matched by identity before the numeric
          branch, because bool is a Number;
        * Python floats become 8-byte doubles;
        * other numbers in [-3, 12] take a one-byte form, the rest go through
          encodeint();
        * bytes, text, lists/tuples, sets and dicts map to their tagged forms;
        * any other iterable is encoded like a sequence;
        * anything else is embedded via ``encode_embedded``.
        """
        if hasattr(v, '__preserve_on__'):
            v.__preserve_on__(self)
        elif v is False:
            self.buffer.append(0x80)
        elif v is True:
            self.buffer.append(0x81)
        elif isinstance(v, float):
            self.buffer.append(0x83)
            self.buffer.extend(struct.pack('>d', v))
        elif isinstance(v, numbers.Number):
            if v >= -3 and v <= 12:
                self.buffer.append(0x90 + (v if v >= 0 else v + 16))
            else:
                self.encodeint(v)
        elif isinstance(v, bytes):
            self.encodebytes(2, v)
        elif isinstance(v, basestring):
            self.encodebytes(1, v.encode('utf-8'))
        elif isinstance(v, list):
            self.encodevalues(5, v)
        elif isinstance(v, tuple):
            self.encodevalues(5, v)
        elif isinstance(v, set):
            self.encodevalues(6, v)
        elif isinstance(v, frozenset):
            self.encodevalues(6, v)
        elif isinstance(v, dict):
            self.encodevalues(7, list(dict_kvs(v)))
        else:
            try:
                i = iter(v)
            except TypeError:
                self.buffer.append(0x86)
                self.append(self.encode_embedded(v))
                return
            self.encodevalues(5, i)
|
|
|
|
def encode(v, **kwargs):
    """Return the encoding of *v* as bytes.

    Keyword arguments are forwarded to Encoder.
    """
    encoder = Encoder(**kwargs)
    encoder.append(v)
    encoded = encoder.contents()
    return encoded
|