forked from syndicate-lang/preserves
Embedded wrapper; preserve() function
This commit is contained in:
parent
a5065955ca
commit
82c66ec1c4
|
@ -1,4 +1,4 @@
|
|||
from .values import Float, Symbol, Record, ImmutableDict
|
||||
from .values import Float, Symbol, Record, ImmutableDict, Embedded, preserve
|
||||
from .values import Annotated, is_annotated, strip_annotations, annotate
|
||||
|
||||
from .error import DecodeError, EncodeError, ShortPacket
|
||||
|
|
|
@ -8,7 +8,7 @@ from .compat import basestring_, ord_
|
|||
class BinaryCodec(object): pass
|
||||
|
||||
class Decoder(BinaryCodec):
|
||||
def __init__(self, packet=b'', include_annotations=False, decode_embedded=None):
|
||||
def __init__(self, packet=b'', include_annotations=False, decode_embedded=lambda x: x):
|
||||
super(Decoder, self).__init__()
|
||||
self.packet = packet
|
||||
self.index = 0
|
||||
|
@ -82,7 +82,7 @@ class Decoder(BinaryCodec):
|
|||
if tag == 0x86:
|
||||
if self.decode_embedded is None:
|
||||
raise DecodeError('No decode_embedded function supplied')
|
||||
return self.wrap(self.decode_embedded(self.next()))
|
||||
return self.wrap(Embedded(self.decode_embedded(self.next())))
|
||||
if tag >= 0x90 and tag <= 0x9f: return self.wrap(tag - (0xa0 if tag > 0x9c else 0x90))
|
||||
if tag >= 0xa0 and tag <= 0xaf: return self.wrap(self.nextint(tag - 0xa0 + 1))
|
||||
if tag == 0xb0: return self.wrap(self.nextint(self.varint()))
|
||||
|
@ -106,6 +106,15 @@ class Decoder(BinaryCodec):
|
|||
self.index = start
|
||||
return None
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
v = self.try_next()
|
||||
if v is None:
|
||||
raise StopIteration
|
||||
return v
|
||||
|
||||
def decode(bs, **kwargs):
|
||||
return Decoder(packet=bs, **kwargs).next()
|
||||
|
||||
|
@ -113,10 +122,15 @@ def decode_with_annotations(bs, **kwargs):
|
|||
return Decoder(packet=bs, include_annotations=True, **kwargs).next()
|
||||
|
||||
class Encoder(BinaryCodec):
|
||||
def __init__(self, encode_embedded=id):
|
||||
def __init__(self, encode_embedded=lambda x: x):
|
||||
super(Encoder, self).__init__()
|
||||
self.buffer = bytearray()
|
||||
self.encode_embedded = encode_embedded
|
||||
self._encode_embedded = encode_embedded
|
||||
|
||||
def encode_embedded(self, v):
|
||||
if self._encode_embedded is None:
|
||||
raise EncodeError('No encode_embedded function supplied')
|
||||
return self._encode_embedded(v)
|
||||
|
||||
def contents(self):
|
||||
return bytes(self.buffer)
|
||||
|
@ -153,9 +167,7 @@ class Encoder(BinaryCodec):
|
|||
self.buffer.extend(bs)
|
||||
|
||||
def append(self, v):
|
||||
while hasattr(v, '__preserve__'):
|
||||
v = v.__preserve__()
|
||||
|
||||
v = preserve(v)
|
||||
if hasattr(v, '__preserve_write_binary__'):
|
||||
v.__preserve_write_binary__(self)
|
||||
elif v is False:
|
||||
|
@ -188,9 +200,7 @@ class Encoder(BinaryCodec):
|
|||
try:
|
||||
i = iter(v)
|
||||
except TypeError:
|
||||
self.buffer.append(0x86)
|
||||
self.append(self.encode_embedded(v))
|
||||
return
|
||||
raise TypeError('Cannot preserves-encode: ' + repr(v))
|
||||
self.encodevalues(5, i)
|
||||
|
||||
def encode(v, **kwargs):
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from .preserves import *
|
||||
from . import *
|
||||
import pathlib
|
||||
import keyword
|
||||
|
||||
|
@ -66,7 +66,8 @@ class SchemaObject:
|
|||
if k == SYMBOL and isinstance(v, Symbol): return v
|
||||
return None
|
||||
if p.key == EMBEDDED:
|
||||
return v ## TODO: reconsider representation of embedded values?
|
||||
if not isinstance(v, Embedded): return None
|
||||
return v.embeddedValue
|
||||
if p.key == LIT:
|
||||
if v == p[0]: return ()
|
||||
return None
|
||||
|
@ -283,7 +284,7 @@ def encode(p, v):
|
|||
if p.key == ATOM:
|
||||
return v
|
||||
if p.key == EMBEDDED:
|
||||
return v ## TODO: reconsider representation of embedded values?
|
||||
return Embedded(v)
|
||||
if p.key == LIT:
|
||||
return p[0]
|
||||
if p.key == SEQOF:
|
||||
|
@ -417,7 +418,10 @@ def load_schema_file(filename):
|
|||
|
||||
# a decorator
|
||||
def extend(cls):
|
||||
return lambda f: setattr(cls, f.__name__, f)
|
||||
def extender(f):
|
||||
setattr(cls, f.__name__, f)
|
||||
return f
|
||||
return extender
|
||||
|
||||
__metaschema_filename = pathlib.Path(__file__).parent / '../../../schema/schema.bin'
|
||||
meta = load_schema_file(__metaschema_filename).schema
|
||||
|
|
|
@ -166,23 +166,23 @@ class BinaryCodecTests(unittest.TestCase):
|
|||
class A:
|
||||
def __init__(self, a):
|
||||
self.a = a
|
||||
a1 = A(1)
|
||||
a2 = A(1)
|
||||
self.assertNotEqual(_e(a1), _e(a2))
|
||||
self.assertEqual(_e(a1), _e(a1))
|
||||
self.assertEqual(ord_(_e(a1)[0]), 0x86)
|
||||
self.assertEqual(ord_(_e(a2)[0]), 0x86)
|
||||
a1 = Embedded(A(1))
|
||||
a2 = Embedded(A(1))
|
||||
self.assertNotEqual(encode(a1, encode_embedded=id), encode(a2, encode_embedded=id))
|
||||
self.assertEqual(encode(a1, encode_embedded=id), encode(a1, encode_embedded=id))
|
||||
self.assertEqual(ord_(encode(a1, encode_embedded=id)[0]), 0x86)
|
||||
self.assertEqual(ord_(encode(a2, encode_embedded=id)[0]), 0x86)
|
||||
|
||||
def test_decode_embedded_absent(self):
|
||||
with self.assertRaises(DecodeError):
|
||||
decode(b'\x86\xa0\xff')
|
||||
decode(b'\x86\xa0\xff', decode_embedded=None)
|
||||
|
||||
def test_encode_embedded(self):
|
||||
objects = []
|
||||
def enc(p):
|
||||
objects.append(p)
|
||||
return len(objects) - 1
|
||||
self.assertEqual(encode([object(), object()], encode_embedded = enc),
|
||||
self.assertEqual(encode([Embedded(object()), Embedded(object())], encode_embedded = enc),
|
||||
b'\xb5\x86\x90\x86\x91\x84')
|
||||
|
||||
def test_decode_embedded(self):
|
||||
|
@ -190,15 +190,15 @@ class BinaryCodecTests(unittest.TestCase):
|
|||
def dec(v):
|
||||
return objects[v]
|
||||
self.assertEqual(decode(b'\xb5\x86\x90\x86\x91\x84', decode_embedded = dec),
|
||||
(123, 234))
|
||||
(Embedded(123), Embedded(234)))
|
||||
|
||||
def load_binary_samples():
|
||||
with open(os.path.join(os.path.dirname(__file__), '../../../tests/samples.bin'), 'rb') as f:
|
||||
return Decoder(f.read(), include_annotations=True, decode_embedded=Embedded).next()
|
||||
return Decoder(f.read(), include_annotations=True, decode_embedded=lambda x: x).next()
|
||||
|
||||
def load_text_samples():
|
||||
with open(os.path.join(os.path.dirname(__file__), '../../../tests/samples.pr'), 'rt') as f:
|
||||
return Parser(f.read(), include_annotations=True, parse_embedded=Embedded).next()
|
||||
return Parser(f.read(), include_annotations=True, parse_embedded=lambda x: x).next()
|
||||
|
||||
class TextCodecTests(unittest.TestCase):
|
||||
def test_samples_bin_eq_txt(self):
|
||||
|
@ -208,8 +208,8 @@ class TextCodecTests(unittest.TestCase):
|
|||
|
||||
def test_txt_roundtrip(self):
|
||||
b = load_binary_samples()
|
||||
s = stringify(b, format_embedded=Embedded.value)
|
||||
self.assertEqual(parse(s, include_annotations=True, parse_embedded=Embedded), b)
|
||||
s = stringify(b, format_embedded=lambda x: x)
|
||||
self.assertEqual(parse(s, include_annotations=True, parse_embedded=lambda x: x), b)
|
||||
|
||||
def add_method(d, tName, fn):
|
||||
if hasattr(fn, 'func_name'):
|
||||
|
@ -277,22 +277,6 @@ def install_exn_test(d, tName, bs, check_proc):
|
|||
self.fail('did not fail as expected')
|
||||
add_method(d, tName, test_exn)
|
||||
|
||||
class Embedded:
|
||||
def __init__(self, v):
|
||||
self.v = strip_annotations(v)
|
||||
|
||||
@staticmethod
|
||||
def value(i):
|
||||
return i.v
|
||||
|
||||
def __eq__(self, other):
|
||||
other = _unwrap(other)
|
||||
if other.__class__ is self.__class__:
|
||||
return self.v == other.v
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.v)
|
||||
|
||||
class CommonTestSuite(unittest.TestCase):
|
||||
TestCases = Record.makeConstructor('TestCases', 'cases')
|
||||
|
||||
|
@ -323,13 +307,13 @@ class CommonTestSuite(unittest.TestCase):
|
|||
raise Exception('Unsupported test kind', t.key)
|
||||
|
||||
def DS(self, bs):
|
||||
return decode(bs, decode_embedded=Embedded)
|
||||
return decode(bs, decode_embedded=lambda x: x)
|
||||
|
||||
def D(self, bs):
|
||||
return decode_with_annotations(bs, decode_embedded=Embedded)
|
||||
return decode_with_annotations(bs, decode_embedded=lambda x: x)
|
||||
|
||||
def E(self, v):
|
||||
return encode(v, encode_embedded=Embedded.value)
|
||||
return encode(v, encode_embedded=lambda x: x)
|
||||
|
||||
class RecordTests(unittest.TestCase):
|
||||
def test_getters(self):
|
||||
|
|
|
@ -10,7 +10,7 @@ from .binary import Decoder
|
|||
class TextCodec(object): pass
|
||||
|
||||
class Parser(TextCodec):
|
||||
def __init__(self, input_buffer=u'', include_annotations=False, parse_embedded=None):
|
||||
def __init__(self, input_buffer=u'', include_annotations=False, parse_embedded=lambda x: x):
|
||||
super(Parser, self).__init__()
|
||||
self.input_buffer = input_buffer
|
||||
self.index = 0
|
||||
|
@ -267,7 +267,9 @@ class Parser(TextCodec):
|
|||
raise DecodeError('ByteString must follow #=')
|
||||
return self.wrap(Decoder(bs_val, include_annotations = self.include_annotations).next())
|
||||
if c == '!':
|
||||
return self.wrap(self.parse_embedded(self.next()))
|
||||
if self.parse_embedded is None:
|
||||
raise DecodeError('No parse_embedded function supplied')
|
||||
return self.wrap(Embedded(self.parse_embedded(self.next())))
|
||||
raise DecodeError('Invalid # syntax')
|
||||
if c == '<':
|
||||
self.skip()
|
||||
|
@ -294,6 +296,15 @@ class Parser(TextCodec):
|
|||
self.index = start
|
||||
return None
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
v = self.try_next()
|
||||
if v is None:
|
||||
raise StopIteration
|
||||
return v
|
||||
|
||||
def parse(bs, **kwargs):
|
||||
return Parser(input_buffer=bs, **kwargs).next()
|
||||
|
||||
|
@ -301,10 +312,15 @@ def parse_with_annotations(bs, **kwargs):
|
|||
return Parser(input_buffer=bs, include_annotations=True, **kwargs).next()
|
||||
|
||||
class Formatter(TextCodec):
|
||||
def __init__(self, format_embedded=None):
|
||||
def __init__(self, format_embedded=lambda x: x):
|
||||
super(Formatter, self).__init__()
|
||||
self.chunks = []
|
||||
self.format_embedded = format_embedded
|
||||
self._format_embedded = format_embedded
|
||||
|
||||
def format_embedded(self, v):
|
||||
if self._format_embedded is None:
|
||||
raise EncodeError('No format_embedded function supplied')
|
||||
return self._format_embedded(v)
|
||||
|
||||
def contents(self):
|
||||
return u''.join(self.chunks)
|
||||
|
@ -330,9 +346,7 @@ class Formatter(TextCodec):
|
|||
self.chunks.append(closer)
|
||||
|
||||
def append(self, v):
|
||||
while hasattr(v, '__preserve__'):
|
||||
v = v.__preserve__()
|
||||
|
||||
v = preserve(v)
|
||||
if hasattr(v, '__preserve_write_text__'):
|
||||
v.__preserve_write_text__(self)
|
||||
elif v is False:
|
||||
|
@ -375,9 +389,7 @@ class Formatter(TextCodec):
|
|||
try:
|
||||
i = iter(v)
|
||||
except TypeError:
|
||||
self.chunks.append('#!')
|
||||
self.append(self.format_embedded(v))
|
||||
return
|
||||
raise TypeError('Cannot preserves-format: ' + repr(v))
|
||||
self.write_seq('[', ']', i)
|
||||
|
||||
def stringify(v, **kwargs):
|
||||
|
|
|
@ -4,6 +4,11 @@ import struct
|
|||
|
||||
from .error import DecodeError
|
||||
|
||||
def preserve(v):
|
||||
while hasattr(v, '__preserve__'):
|
||||
v = v.__preserve__()
|
||||
return v
|
||||
|
||||
class Float(object):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
@ -276,3 +281,23 @@ def _unwrap(x):
|
|||
return x.item
|
||||
else:
|
||||
return x
|
||||
|
||||
class Embedded:
|
||||
def __init__(self, value):
|
||||
self.embeddedValue = value
|
||||
|
||||
def __eq__(self, other):
|
||||
other = _unwrap(other)
|
||||
if other.__class__ is self.__class__:
|
||||
return self.embeddedValue == other.embeddedValue
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.embeddedValue)
|
||||
|
||||
def __preserve_write_binary__(self, encoder):
|
||||
encoder.buffer.append(0x86)
|
||||
encoder.append(encoder.encode_embedded(self.embeddedValue))
|
||||
|
||||
def __preserve_write_text__(self, formatter):
|
||||
formatter.chunks.append('#!')
|
||||
formatter.append(formatter.format_embedded(self.embeddedValue))
|
||||
|
|
Loading…
Reference in New Issue