Embedded wrapper; preserve() function

This commit is contained in:
Tony Garnock-Jones 2021-08-17 14:06:52 -04:00
parent a5065955ca
commit 82c66ec1c4
6 changed files with 92 additions and 57 deletions

View File

@ -1,4 +1,4 @@
from .values import Float, Symbol, Record, ImmutableDict
from .values import Float, Symbol, Record, ImmutableDict, Embedded, preserve
from .values import Annotated, is_annotated, strip_annotations, annotate
from .error import DecodeError, EncodeError, ShortPacket

View File

@ -8,7 +8,7 @@ from .compat import basestring_, ord_
class BinaryCodec(object): pass
class Decoder(BinaryCodec):
def __init__(self, packet=b'', include_annotations=False, decode_embedded=None):
def __init__(self, packet=b'', include_annotations=False, decode_embedded=lambda x: x):
super(Decoder, self).__init__()
self.packet = packet
self.index = 0
@ -82,7 +82,7 @@ class Decoder(BinaryCodec):
if tag == 0x86:
if self.decode_embedded is None:
raise DecodeError('No decode_embedded function supplied')
return self.wrap(self.decode_embedded(self.next()))
return self.wrap(Embedded(self.decode_embedded(self.next())))
if tag >= 0x90 and tag <= 0x9f: return self.wrap(tag - (0xa0 if tag > 0x9c else 0x90))
if tag >= 0xa0 and tag <= 0xaf: return self.wrap(self.nextint(tag - 0xa0 + 1))
if tag == 0xb0: return self.wrap(self.nextint(self.varint()))
@ -106,6 +106,15 @@ class Decoder(BinaryCodec):
self.index = start
return None
def __iter__(self):
return self
def __next__(self):
v = self.try_next()
if v is None:
raise StopIteration
return v
def decode(bs, **kwargs):
return Decoder(packet=bs, **kwargs).next()
@ -113,10 +122,15 @@ def decode_with_annotations(bs, **kwargs):
return Decoder(packet=bs, include_annotations=True, **kwargs).next()
class Encoder(BinaryCodec):
def __init__(self, encode_embedded=id):
def __init__(self, encode_embedded=lambda x: x):
super(Encoder, self).__init__()
self.buffer = bytearray()
self.encode_embedded = encode_embedded
self._encode_embedded = encode_embedded
def encode_embedded(self, v):
if self._encode_embedded is None:
raise EncodeError('No encode_embedded function supplied')
return self._encode_embedded(v)
def contents(self):
return bytes(self.buffer)
@ -153,9 +167,7 @@ class Encoder(BinaryCodec):
self.buffer.extend(bs)
def append(self, v):
while hasattr(v, '__preserve__'):
v = v.__preserve__()
v = preserve(v)
if hasattr(v, '__preserve_write_binary__'):
v.__preserve_write_binary__(self)
elif v is False:
@ -188,9 +200,7 @@ class Encoder(BinaryCodec):
try:
i = iter(v)
except TypeError:
self.buffer.append(0x86)
self.append(self.encode_embedded(v))
return
raise TypeError('Cannot preserves-encode: ' + repr(v))
self.encodevalues(5, i)
def encode(v, **kwargs):

View File

@ -1,4 +1,4 @@
from .preserves import *
from . import *
import pathlib
import keyword
@ -66,7 +66,8 @@ class SchemaObject:
if k == SYMBOL and isinstance(v, Symbol): return v
return None
if p.key == EMBEDDED:
return v ## TODO: reconsider representation of embedded values?
if not isinstance(v, Embedded): return None
return v.embeddedValue
if p.key == LIT:
if v == p[0]: return ()
return None
@ -283,7 +284,7 @@ def encode(p, v):
if p.key == ATOM:
return v
if p.key == EMBEDDED:
return v ## TODO: reconsider representation of embedded values?
return Embedded(v)
if p.key == LIT:
return p[0]
if p.key == SEQOF:
@ -417,7 +418,10 @@ def load_schema_file(filename):
# a decorator
def extend(cls):
return lambda f: setattr(cls, f.__name__, f)
def extender(f):
setattr(cls, f.__name__, f)
return f
return extender
__metaschema_filename = pathlib.Path(__file__).parent / '../../../schema/schema.bin'
meta = load_schema_file(__metaschema_filename).schema

View File

@ -166,23 +166,23 @@ class BinaryCodecTests(unittest.TestCase):
class A:
def __init__(self, a):
self.a = a
a1 = A(1)
a2 = A(1)
self.assertNotEqual(_e(a1), _e(a2))
self.assertEqual(_e(a1), _e(a1))
self.assertEqual(ord_(_e(a1)[0]), 0x86)
self.assertEqual(ord_(_e(a2)[0]), 0x86)
a1 = Embedded(A(1))
a2 = Embedded(A(1))
self.assertNotEqual(encode(a1, encode_embedded=id), encode(a2, encode_embedded=id))
self.assertEqual(encode(a1, encode_embedded=id), encode(a1, encode_embedded=id))
self.assertEqual(ord_(encode(a1, encode_embedded=id)[0]), 0x86)
self.assertEqual(ord_(encode(a2, encode_embedded=id)[0]), 0x86)
def test_decode_embedded_absent(self):
with self.assertRaises(DecodeError):
decode(b'\x86\xa0\xff')
decode(b'\x86\xa0\xff', decode_embedded=None)
def test_encode_embedded(self):
objects = []
def enc(p):
objects.append(p)
return len(objects) - 1
self.assertEqual(encode([object(), object()], encode_embedded = enc),
self.assertEqual(encode([Embedded(object()), Embedded(object())], encode_embedded = enc),
b'\xb5\x86\x90\x86\x91\x84')
def test_decode_embedded(self):
@ -190,15 +190,15 @@ class BinaryCodecTests(unittest.TestCase):
def dec(v):
return objects[v]
self.assertEqual(decode(b'\xb5\x86\x90\x86\x91\x84', decode_embedded = dec),
(123, 234))
(Embedded(123), Embedded(234)))
def load_binary_samples():
with open(os.path.join(os.path.dirname(__file__), '../../../tests/samples.bin'), 'rb') as f:
return Decoder(f.read(), include_annotations=True, decode_embedded=Embedded).next()
return Decoder(f.read(), include_annotations=True, decode_embedded=lambda x: x).next()
def load_text_samples():
with open(os.path.join(os.path.dirname(__file__), '../../../tests/samples.pr'), 'rt') as f:
return Parser(f.read(), include_annotations=True, parse_embedded=Embedded).next()
return Parser(f.read(), include_annotations=True, parse_embedded=lambda x: x).next()
class TextCodecTests(unittest.TestCase):
def test_samples_bin_eq_txt(self):
@ -208,8 +208,8 @@ class TextCodecTests(unittest.TestCase):
def test_txt_roundtrip(self):
b = load_binary_samples()
s = stringify(b, format_embedded=Embedded.value)
self.assertEqual(parse(s, include_annotations=True, parse_embedded=Embedded), b)
s = stringify(b, format_embedded=lambda x: x)
self.assertEqual(parse(s, include_annotations=True, parse_embedded=lambda x: x), b)
def add_method(d, tName, fn):
if hasattr(fn, 'func_name'):
@ -277,22 +277,6 @@ def install_exn_test(d, tName, bs, check_proc):
self.fail('did not fail as expected')
add_method(d, tName, test_exn)
class Embedded:
def __init__(self, v):
self.v = strip_annotations(v)
@staticmethod
def value(i):
return i.v
def __eq__(self, other):
other = _unwrap(other)
if other.__class__ is self.__class__:
return self.v == other.v
def __hash__(self):
return hash(self.v)
class CommonTestSuite(unittest.TestCase):
TestCases = Record.makeConstructor('TestCases', 'cases')
@ -323,13 +307,13 @@ class CommonTestSuite(unittest.TestCase):
raise Exception('Unsupported test kind', t.key)
def DS(self, bs):
return decode(bs, decode_embedded=Embedded)
return decode(bs, decode_embedded=lambda x: x)
def D(self, bs):
return decode_with_annotations(bs, decode_embedded=Embedded)
return decode_with_annotations(bs, decode_embedded=lambda x: x)
def E(self, v):
return encode(v, encode_embedded=Embedded.value)
return encode(v, encode_embedded=lambda x: x)
class RecordTests(unittest.TestCase):
def test_getters(self):

View File

@ -10,7 +10,7 @@ from .binary import Decoder
class TextCodec(object): pass
class Parser(TextCodec):
def __init__(self, input_buffer=u'', include_annotations=False, parse_embedded=None):
def __init__(self, input_buffer=u'', include_annotations=False, parse_embedded=lambda x: x):
super(Parser, self).__init__()
self.input_buffer = input_buffer
self.index = 0
@ -267,7 +267,9 @@ class Parser(TextCodec):
raise DecodeError('ByteString must follow #=')
return self.wrap(Decoder(bs_val, include_annotations = self.include_annotations).next())
if c == '!':
return self.wrap(self.parse_embedded(self.next()))
if self.parse_embedded is None:
raise DecodeError('No parse_embedded function supplied')
return self.wrap(Embedded(self.parse_embedded(self.next())))
raise DecodeError('Invalid # syntax')
if c == '<':
self.skip()
@ -294,6 +296,15 @@ class Parser(TextCodec):
self.index = start
return None
def __iter__(self):
return self
def __next__(self):
v = self.try_next()
if v is None:
raise StopIteration
return v
def parse(bs, **kwargs):
return Parser(input_buffer=bs, **kwargs).next()
@ -301,10 +312,15 @@ def parse_with_annotations(bs, **kwargs):
return Parser(input_buffer=bs, include_annotations=True, **kwargs).next()
class Formatter(TextCodec):
def __init__(self, format_embedded=None):
def __init__(self, format_embedded=lambda x: x):
super(Formatter, self).__init__()
self.chunks = []
self.format_embedded = format_embedded
self._format_embedded = format_embedded
def format_embedded(self, v):
if self._format_embedded is None:
raise EncodeError('No format_embedded function supplied')
return self._format_embedded(v)
def contents(self):
return u''.join(self.chunks)
@ -330,9 +346,7 @@ class Formatter(TextCodec):
self.chunks.append(closer)
def append(self, v):
while hasattr(v, '__preserve__'):
v = v.__preserve__()
v = preserve(v)
if hasattr(v, '__preserve_write_text__'):
v.__preserve_write_text__(self)
elif v is False:
@ -375,9 +389,7 @@ class Formatter(TextCodec):
try:
i = iter(v)
except TypeError:
self.chunks.append('#!')
self.append(self.format_embedded(v))
return
raise TypeError('Cannot preserves-format: ' + repr(v))
self.write_seq('[', ']', i)
def stringify(v, **kwargs):

View File

@ -4,6 +4,11 @@ import struct
from .error import DecodeError
def preserve(v):
while hasattr(v, '__preserve__'):
v = v.__preserve__()
return v
class Float(object):
def __init__(self, value):
self.value = value
@ -276,3 +281,23 @@ def _unwrap(x):
return x.item
else:
return x
class Embedded:
def __init__(self, value):
self.embeddedValue = value
def __eq__(self, other):
other = _unwrap(other)
if other.__class__ is self.__class__:
return self.embeddedValue == other.embeddedValue
def __hash__(self):
return hash(self.embeddedValue)
def __preserve_write_binary__(self, encoder):
encoder.buffer.append(0x86)
encoder.append(encoder.encode_embedded(self.embeddedValue))
def __preserve_write_text__(self, formatter):
formatter.chunks.append('#!')
formatter.append(formatter.format_embedded(self.embeddedValue))