"""Encoder and decoder for a binary serialization of structured values.

Every encoded value starts with a lead byte laid out as ``MMmmaaaa``: a 2-bit
major code, a 2-bit minor code and a 4-bit argument.  Major 0 covers atoms
(booleans, floats, small integers) and stream start/end markers, major 1
length-prefixed binary data (integers, strings, bytes, symbols), major 2
records and major 3 collections (sequences, sets, dictionaries).
"""

import sys
import numbers
import struct

try:
    basestring
except NameError:
    # Python 3 has no separate unicode/str base class.
    basestring = str

if isinstance(chr(123), bytes):
    # Python 2: indexing a byte string yields a one-character string.
    _ord = ord
else:
    # Python 3: indexing bytes already yields an int.
    def _ord(x):
        return x


class Float(object):
    """Wrapper selecting the 4-byte (single-precision) float encoding.

    Plain Python ``float`` values are encoded as 8-byte doubles.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Only instances of exactly this class compare equal.
        return other.__class__ is self.__class__ and self.value == other.value

    def __ne__(self, other):
        # Python 2 does not derive != from ==.
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable on Python 3; a hash
        # is needed so decoded Floats can live in sets and as dict keys.
        return hash(self.value)

    def __repr__(self):
        return 'Float(' + repr(self.value) + ')'

    def __preserve_on__(self, encoder):
        encoder.leadbyte(0, 0, 2)
        encoder.buffer.extend(struct.pack('>f', self.value))


class Symbol(object):
    """A symbolic name, encoded distinctly from an ordinary string."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        return isinstance(other, Symbol) and self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.name)

    def __repr__(self):
        return '#' + self.name

    def __preserve_on__(self, encoder):
        bs = self.name.encode('utf-8')
        encoder.header(1, 3, len(bs))
        encoder.buffer.extend(bs)


class Record(object):
    """A labelled tuple: a ``key`` (label) followed by zero or more ``fields``."""

    def __init__(self, key, fields):
        self.key = key
        self.fields = tuple(fields)
        self.__hash = None  # computed lazily by __hash__

    def __eq__(self, other):
        return isinstance(other, Record) and (self.key, self.fields) == (other.key, other.fields)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        if self.__hash is None:
            self.__hash = hash((self.key, self.fields))
        return self.__hash

    def __repr__(self):
        return str(self.key) + '(' + ', '.join(repr(f) for f in self.fields) + ')'

    def __preserve_on__(self, encoder):
        try:
            index = encoder.shortForms.index(self.key)
        except ValueError:
            index = None
        if index is None:
            # No short form for this label: write it out as the first element.
            encoder.header(2, 3, len(self.fields) + 1)
            encoder.append(self.key)
        else:
            # The short-form index, carried in the minor code, stands for the label.
            encoder.header(2, index, len(self.fields))
        for f in self.fields:
            encoder.append(f)


class ImmutableDict(dict):
    """A hashable, read-only ``dict``; the decoder returns dictionaries as these."""

    def __init__(self, *args, **kwargs):
        # Refuse re-initialisation, which would otherwise mutate the contents.
        # The name is spelled out because ``self.__hash`` is name-mangled.
        if hasattr(self, '_ImmutableDict__hash'):
            raise TypeError('Immutable')
        super(ImmutableDict, self).__init__(*args, **kwargs)
        self.__hash = None

    def __delitem__(self, key):
        raise TypeError('Immutable')

    def __setitem__(self, key, val):
        raise TypeError('Immutable')

    def __ior__(self, other):
        # dict.__ior__ (``d |= x``, Python 3.9+) bypasses update().
        raise TypeError('Immutable')

    def clear(self):
        raise TypeError('Immutable')

    def pop(self, k, d=None):
        raise TypeError('Immutable')

    def popitem(self):
        raise TypeError('Immutable')

    def setdefault(self, k, d=None):
        raise TypeError('Immutable')

    def update(self, *args, **kwargs):
        raise TypeError('Immutable')

    def __hash__(self):
        if self.__hash is None:
            # Order-independent, so dicts that compare equal also hash equal
            # regardless of insertion order.  Values must be hashable.
            self.__hash = hash(frozenset(self.items()))
        return self.__hash

    @staticmethod
    def from_kvs(kvs):
        """Build an ImmutableDict from a flat ``k0, v0, k1, v1, ...`` iterable.

        A trailing key without a value is ignored.
        """
        i = iter(kvs)
        return ImmutableDict(zip(i, i))


def dict_kvs(d):
    """Yield the keys and values of ``d`` interleaved: k0, v0, k1, v1, ..."""
    for k in d:
        yield k
        yield d[k]


class DecodeError(ValueError):
    """Raised when the input bytes are malformed or truncated."""


class EncodeError(ValueError):
    """Raised when a value has no encoding."""


class Codec(object):
    """State shared by Encoder and Decoder: the short-form record labels."""

    def __init__(self):
        # Records whose label equals one of these are encoded with the label's
        # index (0..2) as the minor code instead of the label itself.
        self.shortForms = [Symbol(u'discard'), Symbol(u'capture'), Symbol(u'observe')]

    def set_shortform(self, index, v):
        """Replace the short-form label at ``index`` (0, 1 or 2) with ``v``."""
        if 0 <= index < 3:
            self.shortForms[index] = v
        else:
            raise ValueError('Invalid short form index %r' % (index,))


class Stream(object):
    """Base class for values encoded incrementally from an iterator.

    The encoding is a start marker, whatever ``_emit`` writes, then a matching
    end marker.  Subclasses set ``major``/``minor`` to choose the value kind.
    """

    def __init__(self, iterator):
        self._iterator = iterator

    def __preserve_on__(self, encoder):
        arg = (self.major << 2) | self.minor
        encoder.leadbyte(0, 2, arg)
        self._emit(encoder)
        encoder.leadbyte(0, 3, arg)

    def _emit(self, encoder):
        raise NotImplementedError('Should be implemented in subclasses')


class ValueStream(Stream):
    """Stream whose iterator yields arbitrary encodable values."""
    major = 3

    def _emit(self, encoder):
        for v in self._iterator:
            encoder.append(v)


class SequenceStream(ValueStream):
    """Streamed sequence (decoded by Decoder as a tuple)."""
    minor = 0


class SetStream(ValueStream):
    """Streamed set (decoded by Decoder as a frozenset)."""
    minor = 1


class DictStream(ValueStream):
    """Streamed dictionary; the iterator yields ``(key, value)`` pairs."""
    minor = 2

    def _emit(self, encoder):
        for (k, v) in self._iterator:
            encoder.append(k)
            encoder.append(v)


class BinaryStream(Stream):
    """Stream whose iterator yields ``bytes`` chunks (decoded as bytes)."""
    major = 1
    minor = 2

    def _emit(self, encoder):
        for chunk in self._iterator:
            if not isinstance(chunk, bytes):
                raise EncodeError('Illegal chunk in BinaryStream %r' % (chunk,))
            encoder.append(chunk)


class StringStream(BinaryStream):
    """Binary stream whose joined chunks decode as a UTF-8 string."""
    minor = 1


class SymbolStream(BinaryStream):
    """Binary stream whose joined chunks decode as a Symbol."""
    minor = 3


class Decoder(Codec):
    """Decode values one at a time from the byte string ``packet``."""

    def __init__(self, packet):
        super(Decoder, self).__init__()
        self.packet = packet
        self.index = 0  # offset of the next unread byte

    def peekbyte(self):
        """Return the next byte as an int without consuming it."""
        if self.index < len(self.packet):
            return _ord(self.packet[self.index])
        raise DecodeError('Short packet')

    def advance(self, count=1):
        """Move forward ``count`` bytes and return the offset before moving."""
        start = self.index
        self.index += count
        return start

    def nextbyte(self):
        val = self.peekbyte()
        self.advance()
        return val

    def wirelength(self, arg):
        """Decode a length: values below 15 live in ``arg``, 15 means a varint follows."""
        if arg < 15:
            return arg
        return self.varint()

    def varint(self):
        """Read an unsigned base-128 integer, least significant group first.

        Bytes >= 128 carry 7 payload bits and signal that more bytes follow.
        Iterative, so a long run of continuation bytes cannot exhaust the stack.
        """
        result = 0
        shift = 0
        while True:
            v = self.nextbyte()
            if v < 128:
                return result | (v << shift)
            result |= (v - 128) << shift
            shift += 7

    def nextbytes(self, n):
        """Consume and return the next ``n`` bytes.

        Raises DecodeError if fewer than ``n`` bytes remain, rather than
        returning a silently truncated slice.
        """
        if n > len(self.packet) - self.index:
            raise DecodeError('Short packet')
        start = self.advance(n)
        return self.packet[start:self.index]

    def nextvalues(self, n):
        """Decode ``n`` consecutive values into a list."""
        return [self.next() for _ in range(n)]

    def peekop(self):
        """Split the next lead byte into ``(major, minor, arg)`` without consuming it."""
        b = self.peekbyte()
        return (b >> 6, (b >> 4) & 3, b & 15)

    def nextop(self):
        op = self.peekop()
        self.advance()
        return op

    def peekend(self, arg):
        """True if the next byte is the end marker matching stream ``arg``."""
        return self.peekop() == (0, 3, arg)

    def binarystream(self, arg, minor):
        """Join binary chunks up to the matching end marker, then decode them."""
        chunks = []
        while not self.peekend(arg):
            chunk = self.next()
            if not isinstance(chunk, bytes):
                raise DecodeError('Unexpected non-binary chunk')
            chunks.append(chunk)
        self.advance()  # consume the end marker so decoding can continue after it
        return self.decodebinary(minor, b''.join(chunks))

    def valuestream(self, arg, minor, decoder):
        """Collect values up to the matching end marker and pass them to ``decoder``."""
        items = []
        while not self.peekend(arg):
            items.append(self.next())
        self.advance()  # consume the end marker so decoding can continue after it
        return decoder(minor, items)

    def decodeint(self, bs):
        """Interpret ``bs`` as a big-endian two's-complement integer (empty -> 0)."""
        if len(bs) == 0:
            return 0
        acc = _ord(bs[0])
        if acc & 0x80:
            acc -= 256
        for b in bs[1:]:
            acc = (acc << 8) | _ord(b)
        return acc

    def decodebinary(self, minor, bs):
        """Turn raw bytes into an int (0), str (1), bytes (2) or Symbol (3)."""
        if minor == 0:
            return self.decodeint(bs)
        if minor == 1:
            return bs.decode('utf-8')
        if minor == 2:
            return bs
        if minor == 3:
            return Symbol(bs.decode('utf-8'))

    def decoderecord(self, minor, vs):
        """Build a Record; minor 3 means the label is the first element of ``vs``."""
        if minor == 3:
            if not vs:
                raise DecodeError('Too few elements in encoded record')
            return Record(vs[0], vs[1:])
        return Record(self.shortForms[minor], vs)

    def decodecollection(self, minor, vs):
        """Build a tuple (0), frozenset (1) or ImmutableDict (2) from ``vs``."""
        if minor == 0:
            return tuple(vs)
        if minor == 1:
            return frozenset(vs)
        if minor == 2:
            if len(vs) % 2:
                # A key without a value would otherwise be dropped silently.
                raise DecodeError('Missing dictionary value')
            return ImmutableDict.from_kvs(vs)
        raise DecodeError('Invalid collection type')

    def next(self):
        """Decode and return the next complete value, advancing past it."""
        (major, minor, arg) = self.nextop()
        if major == 0:
            if minor == 0:
                if arg == 0:
                    return False
                if arg == 1:
                    return True
                if arg == 2:
                    return Float(struct.unpack('>f', self.nextbytes(4))[0])
                if arg == 3:
                    return struct.unpack('>d', self.nextbytes(8))[0]
                raise DecodeError('Invalid format A encoding')
            if minor == 1:
                # Small integers: 0..12 as-is, 13..15 stand for -3..-1.
                return arg - 16 if arg > 12 else arg
            if minor == 2:
                # Stream start: the high two bits of arg pick the kind, the low
                # two bits its minor code.
                t = arg >> 2
                n = arg & 3
                if t == 0:
                    raise DecodeError('Invalid format C start byte')
                if t == 1:
                    return self.binarystream(arg, n)
                if t == 2:
                    return self.valuestream(arg, n, self.decoderecord)
                return self.valuestream(arg, n, self.decodecollection)
            # minor == 3: an end marker not consumed by a matching open stream.
            raise DecodeError('Unexpected format C end byte')
        if major == 1:
            return self.decodebinary(minor, self.nextbytes(self.wirelength(arg)))
        if major == 2:
            return self.decoderecord(minor, self.nextvalues(self.wirelength(arg)))
        return self.decodecollection(minor, self.nextvalues(self.wirelength(arg)))


class Encoder(Codec):
    """Accumulate the encodings of appended values in ``buffer``."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.buffer = bytearray()

    def contents(self):
        """Return everything encoded so far as ``bytes``."""
        return bytes(self.buffer)

    def varint(self, v):
        """Write non-negative ``v`` in base 128, least significant group first."""
        while v >= 128:
            self.buffer.append((v % 128) + 128)
            v //= 128
        self.buffer.append(v)

    def leadbyte(self, major, minor, arg):
        self.buffer.append(((major & 3) << 6) | ((minor & 3) << 4) | (arg & 15))

    def header(self, major, minor, wirelength):
        """Write a lead byte carrying ``wirelength``; lengths >= 15 spill into a varint."""
        if wirelength < 15:
            self.leadbyte(major, minor, wirelength)
        else:
            self.leadbyte(major, minor, 15)
            self.varint(wirelength)

    def encodeint(self, v):
        """Write integer ``v`` as a big-endian two's-complement byte string."""
        # One extra bit for the sign; for negatives, ~v is the magnitude to cover.
        bitcount = (~v if v < 0 else v).bit_length() + 1
        bytecount = (bitcount + 7) // 8
        self.header(1, 0, bytecount)
        for shift in range(8 * (bytecount - 1), -1, -8):
            self.buffer.append((v >> shift) & 255)

    def encodecollection(self, minor, items):
        self.header(3, minor, len(items))
        for i in items:
            self.append(i)

    def append(self, v):
        """Encode ``v`` onto the end of the buffer.

        Raises EncodeError if ``v`` has no encoding.
        """
        if hasattr(v, '__preserve_on__'):
            v.__preserve_on__(self)
        elif v is False:
            self.leadbyte(0, 0, 0)
        elif v is True:
            self.leadbyte(0, 0, 1)
        elif isinstance(v, float):
            self.leadbyte(0, 0, 3)
            self.buffer.extend(struct.pack('>d', v))
        elif isinstance(v, numbers.Integral):
            # Only integers have an integer encoding.  Other Number types
            # (Fraction, Decimal, complex) fall through to EncodeError.
            v = int(v)
            if -3 <= v <= 12:
                self.leadbyte(0, 1, v if v >= 0 else v + 16)
            else:
                self.encodeint(v)
        elif isinstance(v, bytes):
            self.header(1, 2, len(v))
            self.buffer.extend(v)
        elif isinstance(v, basestring):
            bs = v.encode('utf-8')
            self.header(1, 1, len(bs))
            self.buffer.extend(bs)
        elif isinstance(v, (list, tuple)):
            self.encodecollection(0, v)
        elif isinstance(v, (set, frozenset)):
            self.encodecollection(1, v)
        elif isinstance(v, dict):
            self.encodecollection(2, list(dict_kvs(v)))
        else:
            try:
                i = iter(v)
            except TypeError:
                raise EncodeError('Cannot encode %r' % (v,))
            # Any other iterable is streamed as a sequence.
            self.append(SequenceStream(i))