forked from syndicate-lang/preserves
Split out modules
This commit is contained in:
parent
cf192b634c
commit
123b6222ca
|
@ -1,7 +1,6 @@
|
|||
from .preserves import Float, Symbol, Record, ImmutableDict
|
||||
from .repr import Float, Symbol, Record, ImmutableDict
|
||||
from .repr import Annotated, is_annotated, strip_annotations, annotate
|
||||
|
||||
from .preserves import DecodeError, EncodeError, ShortPacket
|
||||
from .error import DecodeError, EncodeError, ShortPacket
|
||||
|
||||
from .preserves import Decoder, Encoder, decode, decode_with_annotations, encode
|
||||
|
||||
from .preserves import Annotated, is_annotated, strip_annotations, annotate
|
||||
from .binary import Decoder, Encoder, decode, decode_with_annotations, encode
|
||||
|
|
|
@ -0,0 +1,196 @@
|
|||
import numbers
|
||||
import struct
|
||||
|
||||
from .repr import *
|
||||
from .error import *
|
||||
from .compat import basestring_, ord_
|
||||
|
||||
class Codec(object):
    """Common base class shared by ``Decoder`` and ``Encoder``; adds no behaviour."""
|
||||
|
||||
class Decoder(Codec):
    """Incremental decoder turning a tagged binary packet into Python values.

    The decoder holds a buffer (``packet``) and a read cursor (``index``).
    ``next`` consumes exactly one encoded value; ``try_next`` does the same
    but rewinds and returns ``None`` when the buffer ends too early, so that
    more input can be supplied through ``extend`` before retrying.

    Parameters:
        packet: initial bytes to decode.
        include_annotations: when true, every decoded value is wrapped in
            ``Annotated`` and annotations are attached to it; when false,
            annotations are decoded and then dropped.
        decode_embedded: callable applied to the decoded payload of an
            embedded value (tag 0x86). Decoding such a value without one
            raises ``DecodeError``.
    """

    def __init__(self, packet=b'', include_annotations=False, decode_embedded=None):
        super(Decoder, self).__init__()
        self.packet = packet
        self.index = 0
        self.include_annotations = include_annotations
        self.decode_embedded = decode_embedded

    def extend(self, data):
        """Drop the bytes already consumed and append *data* to what remains."""
        self.packet = self.packet[self.index:] + data
        self.index = 0

    def nextbyte(self):
        """Consume one byte and return it as an int.

        Raises:
            ShortPacket: if the cursor is already at the end of the buffer.
        """
        if self.index >= len(self.packet):
            raise ShortPacket('Short packet')
        self.index = self.index + 1
        # ord_ normalises the indexed element to an int on both Python 2 and 3.
        return ord_(self.packet[self.index - 1])

    def nextbytes(self, n):
        """Consume and return the next *n* bytes as a slice of the buffer.

        Raises:
            ShortPacket: if fewer than *n* bytes remain (nothing is consumed).
        """
        start = self.index
        end = start + n
        if end > len(self.packet):
            raise ShortPacket('Short packet')
        self.index = end
        return self.packet[start:end]

    def varint(self):
        """Read a base-128 varint, least significant group first.

        Every byte carries seven payload bits; a byte below 128 ends the
        number. The loop is iterative so that a long run of continuation
        bytes in the input cannot exhaust the interpreter's recursion limit.

        Raises:
            ShortPacket: if the buffer ends before the terminating byte.
        """
        result = 0
        shift = 0
        while True:
            b = self.nextbyte()
            if b < 128:
                return result + (b << shift)
            result += (b - 128) << shift
            shift += 7

    def peekend(self):
        """Return True and consume the byte if it is the end marker 0x84.

        Any other byte is left unconsumed and False is returned.

        Raises:
            ShortPacket: if no byte is available.
        """
        matched = (self.nextbyte() == 0x84)
        if not matched:
            # Not an end marker: step back so the byte is read again as a tag.
            self.index = self.index - 1
        return matched

    def nextvalues(self):
        """Decode values until the end marker and return them as a list."""
        result = []
        while not self.peekend():
            result.append(self.next())
        return result

    def nextint(self, n):
        """Read an *n*-byte big-endian two's-complement integer (0 when n == 0)."""
        if n == 0:
            return 0
        acc = self.nextbyte()
        if acc & 0x80:
            # The top bit of the first byte is the sign bit.
            acc = acc - 256
        for _i in range(n - 1):
            acc = (acc << 8) | self.nextbyte()
        return acc

    def wrap(self, v):
        """Return *v* wrapped in ``Annotated`` when annotations are requested."""
        return Annotated(v) if self.include_annotations else v

    def unshift_annotation(self, a, v):
        """Prepend annotation *a* to *v* (only when annotations are kept); return *v*."""
        if self.include_annotations:
            v.annotations.insert(0, a)
        return v

    def next(self):
        """Decode and return the next value, dispatching on its tag byte.

        Raises:
            ShortPacket: if the buffer ends in the middle of the value.
            DecodeError: on a stray end marker, an empty record, an embedded
                value without ``decode_embedded``, or an unknown tag.
        """
        tag = self.nextbyte()
        if tag == 0x80:
            return self.wrap(False)
        if tag == 0x81:
            return self.wrap(True)
        if tag == 0x82:
            # 4-byte big-endian single-precision float, kept distinct as Float.
            return self.wrap(Float(struct.unpack('>f', self.nextbytes(4))[0]))
        if tag == 0x83:
            # 8-byte big-endian double, returned as a plain Python float.
            return self.wrap(struct.unpack('>d', self.nextbytes(8))[0])
        if tag == 0x84:
            raise DecodeError('Unexpected end-of-stream marker')
        if tag == 0x85:
            # Annotation: the annotation comes first, then the annotated value.
            a = self.next()
            v = self.next()
            return self.unshift_annotation(a, v)
        if tag == 0x86:
            if self.decode_embedded is None:
                raise DecodeError('No decode_embedded function supplied')
            return self.wrap(self.decode_embedded(self.next()))
        if tag >= 0x90 and tag <= 0x9f:
            # 0x90..0x9c encode 0..12; 0x9d..0x9f encode -3..-1.
            return self.wrap(tag - (0xa0 if tag > 0x9c else 0x90))
        if tag >= 0xa0 and tag <= 0xaf:
            # The low nibble is the byte count minus one (1..16 bytes).
            return self.wrap(self.nextint(tag - 0xa0 + 1))
        if tag == 0xb0:
            return self.wrap(self.nextint(self.varint()))
        if tag == 0xb1:
            return self.wrap(self.nextbytes(self.varint()).decode('utf-8'))
        if tag == 0xb2:
            return self.wrap(self.nextbytes(self.varint()))
        if tag == 0xb3:
            return self.wrap(Symbol(self.nextbytes(self.varint()).decode('utf-8')))
        if tag == 0xb4:
            vs = self.nextvalues()
            if not vs:
                raise DecodeError('Too few elements in encoded record')
            # The first element is the record's key, the rest are its fields.
            return self.wrap(Record(vs[0], vs[1:]))
        if tag == 0xb5:
            return self.wrap(tuple(self.nextvalues()))
        if tag == 0xb6:
            return self.wrap(frozenset(self.nextvalues()))
        if tag == 0xb7:
            return self.wrap(ImmutableDict.from_kvs(self.nextvalues()))
        raise DecodeError('Invalid tag: ' + hex(tag))

    def try_next(self):
        """Like ``next``, but rewind and return None if the input is incomplete.

        Only ``ShortPacket`` is handled; any other error propagates.
        """
        start = self.index
        try:
            return self.next()
        except ShortPacket:
            self.index = start
            return None
|
||||
|
||||
def decode(bs, **kwargs):
    """Decode and return the first value in *bs*.

    Extra keyword arguments are forwarded to the ``Decoder`` constructor.
    """
    reader = Decoder(packet=bs, **kwargs)
    return reader.next()
|
||||
|
||||
def decode_with_annotations(bs, **kwargs):
    """Decode the first value in *bs* with ``include_annotations=True``.

    Extra keyword arguments are forwarded to the ``Decoder`` constructor.
    """
    reader = Decoder(packet=bs, include_annotations=True, **kwargs)
    return reader.next()
|
||||
|
||||
class Encoder(Codec):
    """Serialiser that appends the tagged binary form of values to a buffer.

    Parameters:
        encode_embedded: callable turning a value that has no native
            encoding (and is not iterable) into an encodable value, emitted
            after tag 0x86. Defaults to the builtin ``id``.
    """

    def __init__(self, encode_embedded=id):
        super(Encoder, self).__init__()
        self.buffer = bytearray()
        self.encode_embedded = encode_embedded

    def contents(self):
        """Return everything encoded so far as an immutable ``bytes`` object."""
        return bytes(self.buffer)

    def varint(self, v):
        """Append *v* as a base-128 varint, least significant group first.

        Every byte but the last has 128 added to mark a continuation.
        """
        while v >= 128:
            self.buffer.append((v % 128) + 128)
            v = v // 128
        self.buffer.append(v)

    def encodeint(self, v):
        """Append integer *v* as big-endian two's complement with a length tag.

        Widths of 1..16 bytes use the tags 0xa0..0xaf; wider integers use
        tag 0xb0 followed by a varint byte count.
        """
        # One extra bit is reserved for the sign.
        bitcount = (~v if v < 0 else v).bit_length() + 1
        bytecount = (bitcount + 7) // 8
        if bytecount <= 16:
            self.buffer.append(0xa0 + bytecount - 1)
        else:
            self.buffer.append(0xb0)
            self.varint(bytecount)
        # Emit the most significant byte first. Done with a loop rather than
        # recursion so that very wide integers cannot hit the recursion limit.
        for shift in range(8 * (bytecount - 1), -1, -8):
            self.buffer.append((v >> shift) & 255)

    def encodevalues(self, tag, items):
        """Append tag ``0xb0 + tag``, every item of *items*, then the end marker 0x84."""
        self.buffer.append(0xb0 + tag)
        for i in items:
            self.append(i)
        self.buffer.append(0x84)

    def encodebytes(self, tag, bs):
        """Append tag ``0xb0 + tag``, the varint length of *bs*, then *bs* itself."""
        self.buffer.append(0xb0 + tag)
        self.varint(len(bs))
        self.buffer.extend(bs)

    def append(self, v):
        """Encode *v* and append the result to the buffer.

        Objects exposing ``__preserve_on__`` encode themselves. Booleans,
        floats, integers, byte strings, text, sequences, sets and dicts have
        built-in encodings. Any other iterable is written as a sequence, and
        anything else is passed through ``encode_embedded`` and written as
        an embedded value.
        """
        if hasattr(v, '__preserve_on__'):
            v.__preserve_on__(self)
        elif v is False:
            # Booleans are tested by identity before the numeric branch,
            # because bool instances are also numbers.
            self.buffer.append(0x80)
        elif v is True:
            self.buffer.append(0x81)
        elif isinstance(v, float):
            self.buffer.append(0x83)
            self.buffer.extend(struct.pack('>d', v))
        elif isinstance(v, numbers.Number):
            if v >= -3 and v <= 12:
                # Small integers fit in the tag byte: 0x90..0x9c hold 0..12
                # and 0x9d..0x9f hold -3..-1.
                self.buffer.append(0x90 + (v if v >= 0 else v + 16))
            else:
                self.encodeint(v)
        elif isinstance(v, bytes):
            self.encodebytes(2, v)
        elif isinstance(v, basestring_):
            self.encodebytes(1, v.encode('utf-8'))
        elif isinstance(v, (list, tuple)):
            self.encodevalues(5, v)
        elif isinstance(v, (set, frozenset)):
            self.encodevalues(6, v)
        elif isinstance(v, dict):
            self.encodevalues(7, list(dict_kvs(v)))
        else:
            try:
                i = iter(v)
            except TypeError:
                # Not iterable: emit as an embedded value.
                self.buffer.append(0x86)
                self.append(self.encode_embedded(v))
                return
            self.encodevalues(5, i)
|
||||
|
||||
def encode(v, **kwargs):
    """Return the binary encoding of *v* as ``bytes``.

    Keyword arguments are forwarded to the ``Encoder`` constructor.
    """
    writer = Encoder(**kwargs)
    writer.append(v)
    return writer.contents()
|
|
@ -0,0 +1,9 @@
|
|||
# ``basestring`` only exists as a builtin on Python 2; use ``str`` when the
# name lookup fails.
try:
    basestring
except NameError:
    basestring_ = str
else:
    basestring_ = basestring
|
||||
|
||||
# When chr() yields a byte string, elements of a packet need ord() to become
# ints; otherwise ord_ is the identity function.
ord_ = ord if isinstance(chr(123), bytes) else (lambda x: x)
|
|
@ -0,0 +1,3 @@
|
|||
class DecodeError(ValueError):
    """Raised when encoded input cannot be decoded."""
|
||||
class EncodeError(ValueError):
    """Raised when a value cannot be encoded."""
|
||||
class ShortPacket(DecodeError):
    """Raised when the input ends before a complete value has been read."""
|
|
@ -1,16 +1,7 @@
|
|||
import sys
|
||||
import numbers
|
||||
import struct
|
||||
|
||||
# ``basestring`` only exists as a builtin on Python 2; bind it to ``str``
# when the name lookup fails.
try:
    basestring
except NameError:
    basestring = str

# When chr() yields a byte string, elements of a packet need ord() to become
# ints; otherwise _ord is the identity function.
_ord = ord if isinstance(chr(123), bytes) else (lambda x: x)
|
||||
from .error import DecodeError
|
||||
|
||||
class Float(object):
|
||||
def __init__(self, value):
|
||||
|
@ -35,10 +26,7 @@ class Float(object):
|
|||
|
||||
class Symbol(object):
|
||||
def __init__(self, name):
|
||||
if isinstance(name, Symbol):
|
||||
self.name = name.name
|
||||
else:
|
||||
self.name = name
|
||||
self.name = name.name if isinstance(name, Symbol) else name
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Symbol) and self.name == other.name
|
||||
|
@ -185,12 +173,6 @@ def dict_kvs(d):
|
|||
yield k
|
||||
yield d[k]
|
||||
|
||||
class DecodeError(ValueError):
    """Raised when encoded input cannot be decoded."""
|
||||
class EncodeError(ValueError):
    """Raised when a value cannot be encoded."""
|
||||
class ShortPacket(DecodeError):
    """Raised when the input ends before a complete value has been read."""
|
||||
|
||||
class Codec(object):
    """Common base class shared by ``Decoder`` and ``Encoder``; adds no behaviour."""
|
||||
|
||||
# Positive floating-point infinity (its users are outside this view).
inf = float('inf')
|
||||
|
||||
class Annotated(object):
|
||||
|
@ -258,191 +240,3 @@ def annotate(v, *anns):
|
|||
for a in anns:
|
||||
v.annotations.append(a)
|
||||
return v
|
||||
|
||||
class Decoder(Codec):
    """Incremental decoder turning a tagged binary packet into Python values.

    The decoder holds a buffer (``packet``) and a read cursor (``index``).
    ``next`` consumes exactly one encoded value; ``try_next`` does the same
    but rewinds and returns ``None`` when the buffer ends too early, so that
    more input can be supplied through ``extend`` before retrying.

    Parameters:
        packet: initial bytes to decode.
        include_annotations: when true, every decoded value is wrapped in
            ``Annotated`` and annotations are attached to it; when false,
            annotations are decoded and then dropped.
        decode_embedded: callable applied to the decoded payload of an
            embedded value (tag 0x86). Decoding such a value without one
            raises ``DecodeError``.
    """

    def __init__(self, packet=b'', include_annotations=False, decode_embedded=None):
        super(Decoder, self).__init__()
        self.packet = packet
        self.index = 0
        self.include_annotations = include_annotations
        self.decode_embedded = decode_embedded

    def extend(self, data):
        """Drop the bytes already consumed and append *data* to what remains."""
        self.packet = self.packet[self.index:] + data
        self.index = 0

    def nextbyte(self):
        """Consume one byte and return it as an int.

        Raises:
            ShortPacket: if the cursor is already at the end of the buffer.
        """
        if self.index >= len(self.packet):
            raise ShortPacket('Short packet')
        self.index = self.index + 1
        # _ord normalises the indexed element to an int on both Python 2 and 3.
        return _ord(self.packet[self.index - 1])

    def nextbytes(self, n):
        """Consume and return the next *n* bytes as a slice of the buffer.

        Raises:
            ShortPacket: if fewer than *n* bytes remain (nothing is consumed).
        """
        start = self.index
        end = start + n
        if end > len(self.packet):
            raise ShortPacket('Short packet')
        self.index = end
        return self.packet[start:end]

    def varint(self):
        """Read a base-128 varint, least significant group first.

        Every byte carries seven payload bits; a byte below 128 ends the
        number. The loop is iterative so that a long run of continuation
        bytes in the input cannot exhaust the interpreter's recursion limit.

        Raises:
            ShortPacket: if the buffer ends before the terminating byte.
        """
        result = 0
        shift = 0
        while True:
            b = self.nextbyte()
            if b < 128:
                return result + (b << shift)
            result += (b - 128) << shift
            shift += 7

    def peekend(self):
        """Return True and consume the byte if it is the end marker 0x84.

        Any other byte is left unconsumed and False is returned.

        Raises:
            ShortPacket: if no byte is available.
        """
        matched = (self.nextbyte() == 0x84)
        if not matched:
            # Not an end marker: step back so the byte is read again as a tag.
            self.index = self.index - 1
        return matched

    def nextvalues(self):
        """Decode values until the end marker and return them as a list."""
        result = []
        while not self.peekend():
            result.append(self.next())
        return result

    def nextint(self, n):
        """Read an *n*-byte big-endian two's-complement integer (0 when n == 0)."""
        if n == 0:
            return 0
        acc = self.nextbyte()
        if acc & 0x80:
            # The top bit of the first byte is the sign bit.
            acc = acc - 256
        for _i in range(n - 1):
            acc = (acc << 8) | self.nextbyte()
        return acc

    def wrap(self, v):
        """Return *v* wrapped in ``Annotated`` when annotations are requested."""
        return Annotated(v) if self.include_annotations else v

    def unshift_annotation(self, a, v):
        """Prepend annotation *a* to *v* (only when annotations are kept); return *v*."""
        if self.include_annotations:
            v.annotations.insert(0, a)
        return v

    def next(self):
        """Decode and return the next value, dispatching on its tag byte.

        Raises:
            ShortPacket: if the buffer ends in the middle of the value.
            DecodeError: on a stray end marker, an empty record, an embedded
                value without ``decode_embedded``, or an unknown tag.
        """
        tag = self.nextbyte()
        if tag == 0x80:
            return self.wrap(False)
        if tag == 0x81:
            return self.wrap(True)
        if tag == 0x82:
            # 4-byte big-endian single-precision float, kept distinct as Float.
            return self.wrap(Float(struct.unpack('>f', self.nextbytes(4))[0]))
        if tag == 0x83:
            # 8-byte big-endian double, returned as a plain Python float.
            return self.wrap(struct.unpack('>d', self.nextbytes(8))[0])
        if tag == 0x84:
            raise DecodeError('Unexpected end-of-stream marker')
        if tag == 0x85:
            # Annotation: the annotation comes first, then the annotated value.
            a = self.next()
            v = self.next()
            return self.unshift_annotation(a, v)
        if tag == 0x86:
            if self.decode_embedded is None:
                raise DecodeError('No decode_embedded function supplied')
            return self.wrap(self.decode_embedded(self.next()))
        if tag >= 0x90 and tag <= 0x9f:
            # 0x90..0x9c encode 0..12; 0x9d..0x9f encode -3..-1.
            return self.wrap(tag - (0xa0 if tag > 0x9c else 0x90))
        if tag >= 0xa0 and tag <= 0xaf:
            # The low nibble is the byte count minus one (1..16 bytes).
            return self.wrap(self.nextint(tag - 0xa0 + 1))
        if tag == 0xb0:
            return self.wrap(self.nextint(self.varint()))
        if tag == 0xb1:
            return self.wrap(self.nextbytes(self.varint()).decode('utf-8'))
        if tag == 0xb2:
            return self.wrap(self.nextbytes(self.varint()))
        if tag == 0xb3:
            return self.wrap(Symbol(self.nextbytes(self.varint()).decode('utf-8')))
        if tag == 0xb4:
            vs = self.nextvalues()
            if not vs:
                raise DecodeError('Too few elements in encoded record')
            # The first element is the record's key, the rest are its fields.
            return self.wrap(Record(vs[0], vs[1:]))
        if tag == 0xb5:
            return self.wrap(tuple(self.nextvalues()))
        if tag == 0xb6:
            return self.wrap(frozenset(self.nextvalues()))
        if tag == 0xb7:
            return self.wrap(ImmutableDict.from_kvs(self.nextvalues()))
        raise DecodeError('Invalid tag: ' + hex(tag))

    def try_next(self):
        """Like ``next``, but rewind and return None if the input is incomplete.

        Only ``ShortPacket`` is handled; any other error propagates.
        """
        start = self.index
        try:
            return self.next()
        except ShortPacket:
            self.index = start
            return None
|
||||
|
||||
def decode(bs, **kwargs):
    """Decode and return the first value in *bs*.

    Extra keyword arguments are forwarded to the ``Decoder`` constructor.
    """
    reader = Decoder(packet=bs, **kwargs)
    return reader.next()
|
||||
|
||||
def decode_with_annotations(bs, **kwargs):
    """Decode the first value in *bs* with ``include_annotations=True``.

    Extra keyword arguments are forwarded to the ``Decoder`` constructor.
    """
    reader = Decoder(packet=bs, include_annotations=True, **kwargs)
    return reader.next()
|
||||
|
||||
class Encoder(Codec):
    """Serialiser that appends the tagged binary form of values to a buffer.

    Parameters:
        encode_embedded: callable turning a value that has no native
            encoding (and is not iterable) into an encodable value, emitted
            after tag 0x86. Defaults to the builtin ``id``.
    """

    def __init__(self, encode_embedded=id):
        super(Encoder, self).__init__()
        self.buffer = bytearray()
        self.encode_embedded = encode_embedded

    def contents(self):
        """Return everything encoded so far as an immutable ``bytes`` object."""
        return bytes(self.buffer)

    def varint(self, v):
        """Append *v* as a base-128 varint, least significant group first.

        Every byte but the last has 128 added to mark a continuation.
        """
        while v >= 128:
            self.buffer.append((v % 128) + 128)
            v = v // 128
        self.buffer.append(v)

    def encodeint(self, v):
        """Append integer *v* as big-endian two's complement with a length tag.

        Widths of 1..16 bytes use the tags 0xa0..0xaf; wider integers use
        tag 0xb0 followed by a varint byte count.
        """
        # One extra bit is reserved for the sign.
        bitcount = (~v if v < 0 else v).bit_length() + 1
        bytecount = (bitcount + 7) // 8
        if bytecount <= 16:
            self.buffer.append(0xa0 + bytecount - 1)
        else:
            self.buffer.append(0xb0)
            self.varint(bytecount)
        # Emit the most significant byte first. Done with a loop rather than
        # recursion so that very wide integers cannot hit the recursion limit.
        for shift in range(8 * (bytecount - 1), -1, -8):
            self.buffer.append((v >> shift) & 255)

    def encodevalues(self, tag, items):
        """Append tag ``0xb0 + tag``, every item of *items*, then the end marker 0x84."""
        self.buffer.append(0xb0 + tag)
        for i in items:
            self.append(i)
        self.buffer.append(0x84)

    def encodebytes(self, tag, bs):
        """Append tag ``0xb0 + tag``, the varint length of *bs*, then *bs* itself."""
        self.buffer.append(0xb0 + tag)
        self.varint(len(bs))
        self.buffer.extend(bs)

    def append(self, v):
        """Encode *v* and append the result to the buffer.

        Objects exposing ``__preserve_on__`` encode themselves. Booleans,
        floats, integers, byte strings, text, sequences, sets and dicts have
        built-in encodings. Any other iterable is written as a sequence, and
        anything else is passed through ``encode_embedded`` and written as
        an embedded value.
        """
        if hasattr(v, '__preserve_on__'):
            v.__preserve_on__(self)
        elif v is False:
            # Booleans are tested by identity before the numeric branch,
            # because bool instances are also numbers.
            self.buffer.append(0x80)
        elif v is True:
            self.buffer.append(0x81)
        elif isinstance(v, float):
            self.buffer.append(0x83)
            self.buffer.extend(struct.pack('>d', v))
        elif isinstance(v, numbers.Number):
            if v >= -3 and v <= 12:
                # Small integers fit in the tag byte: 0x90..0x9c hold 0..12
                # and 0x9d..0x9f hold -3..-1.
                self.buffer.append(0x90 + (v if v >= 0 else v + 16))
            else:
                self.encodeint(v)
        elif isinstance(v, bytes):
            self.encodebytes(2, v)
        elif isinstance(v, basestring):
            self.encodebytes(1, v.encode('utf-8'))
        elif isinstance(v, (list, tuple)):
            self.encodevalues(5, v)
        elif isinstance(v, (set, frozenset)):
            self.encodevalues(6, v)
        elif isinstance(v, dict):
            self.encodevalues(7, list(dict_kvs(v)))
        else:
            try:
                i = iter(v)
            except TypeError:
                # Not iterable: emit as an embedded value.
                self.buffer.append(0x86)
                self.append(self.encode_embedded(v))
                return
            self.encodevalues(5, i)
|
||||
|
||||
def encode(v, **kwargs):
    """Return the binary encoding of *v* as ``bytes``.

    Keyword arguments are forwarded to the ``Encoder`` constructor.
    """
    writer = Encoder(**kwargs)
    writer.append(v)
    return writer.contents()
|
|
@ -1,6 +1,9 @@
|
|||
from .preserves import *
|
||||
import unittest
|
||||
import numbers
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from . import *
|
||||
from .compat import basestring_, ord_
|
||||
|
||||
if isinstance(chr(123), bytes):
|
||||
def _byte(x):
|
||||
|
@ -18,7 +21,7 @@ def _buf(*args):
|
|||
for chunk in args:
|
||||
if isinstance(chunk, bytes):
|
||||
result.append(chunk)
|
||||
elif isinstance(chunk, basestring):
|
||||
elif isinstance(chunk, basestring_):
|
||||
result.append(chunk.encode('utf-8'))
|
||||
elif isinstance(chunk, numbers.Number):
|
||||
result.append(_byte(chunk))
|
||||
|
@ -165,9 +168,8 @@ class CodecTests(unittest.TestCase):
|
|||
a2 = A(1)
|
||||
self.assertNotEqual(_e(a1), _e(a2))
|
||||
self.assertEqual(_e(a1), _e(a1))
|
||||
from .preserves import _ord
|
||||
self.assertEqual(_ord(_e(a1)[0]), 0x86)
|
||||
self.assertEqual(_ord(_e(a2)[0]), 0x86)
|
||||
self.assertEqual(ord_(_e(a1)[0]), 0x86)
|
||||
self.assertEqual(ord_(_e(a2)[0]), 0x86)
|
||||
|
||||
def test_decode_embedded_absent(self):
|
||||
with self.assertRaises(DecodeError):
|
||||
|
|
Loading…
Reference in New Issue