Embedded wrapper; preserve() function

This commit is contained in:
Tony Garnock-Jones 2021-08-17 14:06:52 -04:00
parent a5065955ca
commit 82c66ec1c4
6 changed files with 92 additions and 57 deletions

View File

@ -1,4 +1,4 @@
from .values import Float, Symbol, Record, ImmutableDict from .values import Float, Symbol, Record, ImmutableDict, Embedded, preserve
from .values import Annotated, is_annotated, strip_annotations, annotate from .values import Annotated, is_annotated, strip_annotations, annotate
from .error import DecodeError, EncodeError, ShortPacket from .error import DecodeError, EncodeError, ShortPacket

View File

@ -8,7 +8,7 @@ from .compat import basestring_, ord_
class BinaryCodec(object): pass class BinaryCodec(object): pass
class Decoder(BinaryCodec): class Decoder(BinaryCodec):
def __init__(self, packet=b'', include_annotations=False, decode_embedded=None): def __init__(self, packet=b'', include_annotations=False, decode_embedded=lambda x: x):
super(Decoder, self).__init__() super(Decoder, self).__init__()
self.packet = packet self.packet = packet
self.index = 0 self.index = 0
@ -82,7 +82,7 @@ class Decoder(BinaryCodec):
if tag == 0x86: if tag == 0x86:
if self.decode_embedded is None: if self.decode_embedded is None:
raise DecodeError('No decode_embedded function supplied') raise DecodeError('No decode_embedded function supplied')
return self.wrap(self.decode_embedded(self.next())) return self.wrap(Embedded(self.decode_embedded(self.next())))
if tag >= 0x90 and tag <= 0x9f: return self.wrap(tag - (0xa0 if tag > 0x9c else 0x90)) if tag >= 0x90 and tag <= 0x9f: return self.wrap(tag - (0xa0 if tag > 0x9c else 0x90))
if tag >= 0xa0 and tag <= 0xaf: return self.wrap(self.nextint(tag - 0xa0 + 1)) if tag >= 0xa0 and tag <= 0xaf: return self.wrap(self.nextint(tag - 0xa0 + 1))
if tag == 0xb0: return self.wrap(self.nextint(self.varint())) if tag == 0xb0: return self.wrap(self.nextint(self.varint()))
@ -106,6 +106,15 @@ class Decoder(BinaryCodec):
self.index = start self.index = start
return None return None
def __iter__(self):
return self
def __next__(self):
v = self.try_next()
if v is None:
raise StopIteration
return v
def decode(bs, **kwargs): def decode(bs, **kwargs):
return Decoder(packet=bs, **kwargs).next() return Decoder(packet=bs, **kwargs).next()
@ -113,10 +122,15 @@ def decode_with_annotations(bs, **kwargs):
return Decoder(packet=bs, include_annotations=True, **kwargs).next() return Decoder(packet=bs, include_annotations=True, **kwargs).next()
class Encoder(BinaryCodec): class Encoder(BinaryCodec):
def __init__(self, encode_embedded=id): def __init__(self, encode_embedded=lambda x: x):
super(Encoder, self).__init__() super(Encoder, self).__init__()
self.buffer = bytearray() self.buffer = bytearray()
self.encode_embedded = encode_embedded self._encode_embedded = encode_embedded
def encode_embedded(self, v):
if self._encode_embedded is None:
raise EncodeError('No encode_embedded function supplied')
return self._encode_embedded(v)
def contents(self): def contents(self):
return bytes(self.buffer) return bytes(self.buffer)
@ -153,9 +167,7 @@ class Encoder(BinaryCodec):
self.buffer.extend(bs) self.buffer.extend(bs)
def append(self, v): def append(self, v):
while hasattr(v, '__preserve__'): v = preserve(v)
v = v.__preserve__()
if hasattr(v, '__preserve_write_binary__'): if hasattr(v, '__preserve_write_binary__'):
v.__preserve_write_binary__(self) v.__preserve_write_binary__(self)
elif v is False: elif v is False:
@ -188,9 +200,7 @@ class Encoder(BinaryCodec):
try: try:
i = iter(v) i = iter(v)
except TypeError: except TypeError:
self.buffer.append(0x86) raise TypeError('Cannot preserves-encode: ' + repr(v))
self.append(self.encode_embedded(v))
return
self.encodevalues(5, i) self.encodevalues(5, i)
def encode(v, **kwargs): def encode(v, **kwargs):

View File

@ -1,4 +1,4 @@
from .preserves import * from . import *
import pathlib import pathlib
import keyword import keyword
@ -66,7 +66,8 @@ class SchemaObject:
if k == SYMBOL and isinstance(v, Symbol): return v if k == SYMBOL and isinstance(v, Symbol): return v
return None return None
if p.key == EMBEDDED: if p.key == EMBEDDED:
return v ## TODO: reconsider representation of embedded values? if not isinstance(v, Embedded): return None
return v.embeddedValue
if p.key == LIT: if p.key == LIT:
if v == p[0]: return () if v == p[0]: return ()
return None return None
@ -283,7 +284,7 @@ def encode(p, v):
if p.key == ATOM: if p.key == ATOM:
return v return v
if p.key == EMBEDDED: if p.key == EMBEDDED:
return v ## TODO: reconsider representation of embedded values? return Embedded(v)
if p.key == LIT: if p.key == LIT:
return p[0] return p[0]
if p.key == SEQOF: if p.key == SEQOF:
@ -417,7 +418,10 @@ def load_schema_file(filename):
# a decorator # a decorator
def extend(cls): def extend(cls):
return lambda f: setattr(cls, f.__name__, f) def extender(f):
setattr(cls, f.__name__, f)
return f
return extender
__metaschema_filename = pathlib.Path(__file__).parent / '../../../schema/schema.bin' __metaschema_filename = pathlib.Path(__file__).parent / '../../../schema/schema.bin'
meta = load_schema_file(__metaschema_filename).schema meta = load_schema_file(__metaschema_filename).schema

View File

@ -166,23 +166,23 @@ class BinaryCodecTests(unittest.TestCase):
class A: class A:
def __init__(self, a): def __init__(self, a):
self.a = a self.a = a
a1 = A(1) a1 = Embedded(A(1))
a2 = A(1) a2 = Embedded(A(1))
self.assertNotEqual(_e(a1), _e(a2)) self.assertNotEqual(encode(a1, encode_embedded=id), encode(a2, encode_embedded=id))
self.assertEqual(_e(a1), _e(a1)) self.assertEqual(encode(a1, encode_embedded=id), encode(a1, encode_embedded=id))
self.assertEqual(ord_(_e(a1)[0]), 0x86) self.assertEqual(ord_(encode(a1, encode_embedded=id)[0]), 0x86)
self.assertEqual(ord_(_e(a2)[0]), 0x86) self.assertEqual(ord_(encode(a2, encode_embedded=id)[0]), 0x86)
def test_decode_embedded_absent(self): def test_decode_embedded_absent(self):
with self.assertRaises(DecodeError): with self.assertRaises(DecodeError):
decode(b'\x86\xa0\xff') decode(b'\x86\xa0\xff', decode_embedded=None)
def test_encode_embedded(self): def test_encode_embedded(self):
objects = [] objects = []
def enc(p): def enc(p):
objects.append(p) objects.append(p)
return len(objects) - 1 return len(objects) - 1
self.assertEqual(encode([object(), object()], encode_embedded = enc), self.assertEqual(encode([Embedded(object()), Embedded(object())], encode_embedded = enc),
b'\xb5\x86\x90\x86\x91\x84') b'\xb5\x86\x90\x86\x91\x84')
def test_decode_embedded(self): def test_decode_embedded(self):
@ -190,15 +190,15 @@ class BinaryCodecTests(unittest.TestCase):
def dec(v): def dec(v):
return objects[v] return objects[v]
self.assertEqual(decode(b'\xb5\x86\x90\x86\x91\x84', decode_embedded = dec), self.assertEqual(decode(b'\xb5\x86\x90\x86\x91\x84', decode_embedded = dec),
(123, 234)) (Embedded(123), Embedded(234)))
def load_binary_samples(): def load_binary_samples():
with open(os.path.join(os.path.dirname(__file__), '../../../tests/samples.bin'), 'rb') as f: with open(os.path.join(os.path.dirname(__file__), '../../../tests/samples.bin'), 'rb') as f:
return Decoder(f.read(), include_annotations=True, decode_embedded=Embedded).next() return Decoder(f.read(), include_annotations=True, decode_embedded=lambda x: x).next()
def load_text_samples(): def load_text_samples():
with open(os.path.join(os.path.dirname(__file__), '../../../tests/samples.pr'), 'rt') as f: with open(os.path.join(os.path.dirname(__file__), '../../../tests/samples.pr'), 'rt') as f:
return Parser(f.read(), include_annotations=True, parse_embedded=Embedded).next() return Parser(f.read(), include_annotations=True, parse_embedded=lambda x: x).next()
class TextCodecTests(unittest.TestCase): class TextCodecTests(unittest.TestCase):
def test_samples_bin_eq_txt(self): def test_samples_bin_eq_txt(self):
@ -208,8 +208,8 @@ class TextCodecTests(unittest.TestCase):
def test_txt_roundtrip(self): def test_txt_roundtrip(self):
b = load_binary_samples() b = load_binary_samples()
s = stringify(b, format_embedded=Embedded.value) s = stringify(b, format_embedded=lambda x: x)
self.assertEqual(parse(s, include_annotations=True, parse_embedded=Embedded), b) self.assertEqual(parse(s, include_annotations=True, parse_embedded=lambda x: x), b)
def add_method(d, tName, fn): def add_method(d, tName, fn):
if hasattr(fn, 'func_name'): if hasattr(fn, 'func_name'):
@ -277,22 +277,6 @@ def install_exn_test(d, tName, bs, check_proc):
self.fail('did not fail as expected') self.fail('did not fail as expected')
add_method(d, tName, test_exn) add_method(d, tName, test_exn)
class Embedded:
def __init__(self, v):
self.v = strip_annotations(v)
@staticmethod
def value(i):
return i.v
def __eq__(self, other):
other = _unwrap(other)
if other.__class__ is self.__class__:
return self.v == other.v
def __hash__(self):
return hash(self.v)
class CommonTestSuite(unittest.TestCase): class CommonTestSuite(unittest.TestCase):
TestCases = Record.makeConstructor('TestCases', 'cases') TestCases = Record.makeConstructor('TestCases', 'cases')
@ -323,13 +307,13 @@ class CommonTestSuite(unittest.TestCase):
raise Exception('Unsupported test kind', t.key) raise Exception('Unsupported test kind', t.key)
def DS(self, bs): def DS(self, bs):
return decode(bs, decode_embedded=Embedded) return decode(bs, decode_embedded=lambda x: x)
def D(self, bs): def D(self, bs):
return decode_with_annotations(bs, decode_embedded=Embedded) return decode_with_annotations(bs, decode_embedded=lambda x: x)
def E(self, v): def E(self, v):
return encode(v, encode_embedded=Embedded.value) return encode(v, encode_embedded=lambda x: x)
class RecordTests(unittest.TestCase): class RecordTests(unittest.TestCase):
def test_getters(self): def test_getters(self):

View File

@ -10,7 +10,7 @@ from .binary import Decoder
class TextCodec(object): pass class TextCodec(object): pass
class Parser(TextCodec): class Parser(TextCodec):
def __init__(self, input_buffer=u'', include_annotations=False, parse_embedded=None): def __init__(self, input_buffer=u'', include_annotations=False, parse_embedded=lambda x: x):
super(Parser, self).__init__() super(Parser, self).__init__()
self.input_buffer = input_buffer self.input_buffer = input_buffer
self.index = 0 self.index = 0
@ -267,7 +267,9 @@ class Parser(TextCodec):
raise DecodeError('ByteString must follow #=') raise DecodeError('ByteString must follow #=')
return self.wrap(Decoder(bs_val, include_annotations = self.include_annotations).next()) return self.wrap(Decoder(bs_val, include_annotations = self.include_annotations).next())
if c == '!': if c == '!':
return self.wrap(self.parse_embedded(self.next())) if self.parse_embedded is None:
raise DecodeError('No parse_embedded function supplied')
return self.wrap(Embedded(self.parse_embedded(self.next())))
raise DecodeError('Invalid # syntax') raise DecodeError('Invalid # syntax')
if c == '<': if c == '<':
self.skip() self.skip()
@ -294,6 +296,15 @@ class Parser(TextCodec):
self.index = start self.index = start
return None return None
def __iter__(self):
return self
def __next__(self):
v = self.try_next()
if v is None:
raise StopIteration
return v
def parse(bs, **kwargs): def parse(bs, **kwargs):
return Parser(input_buffer=bs, **kwargs).next() return Parser(input_buffer=bs, **kwargs).next()
@ -301,10 +312,15 @@ def parse_with_annotations(bs, **kwargs):
return Parser(input_buffer=bs, include_annotations=True, **kwargs).next() return Parser(input_buffer=bs, include_annotations=True, **kwargs).next()
class Formatter(TextCodec): class Formatter(TextCodec):
def __init__(self, format_embedded=None): def __init__(self, format_embedded=lambda x: x):
super(Formatter, self).__init__() super(Formatter, self).__init__()
self.chunks = [] self.chunks = []
self.format_embedded = format_embedded self._format_embedded = format_embedded
def format_embedded(self, v):
if self._format_embedded is None:
raise EncodeError('No format_embedded function supplied')
return self._format_embedded(v)
def contents(self): def contents(self):
return u''.join(self.chunks) return u''.join(self.chunks)
@ -330,9 +346,7 @@ class Formatter(TextCodec):
self.chunks.append(closer) self.chunks.append(closer)
def append(self, v): def append(self, v):
while hasattr(v, '__preserve__'): v = preserve(v)
v = v.__preserve__()
if hasattr(v, '__preserve_write_text__'): if hasattr(v, '__preserve_write_text__'):
v.__preserve_write_text__(self) v.__preserve_write_text__(self)
elif v is False: elif v is False:
@ -375,9 +389,7 @@ class Formatter(TextCodec):
try: try:
i = iter(v) i = iter(v)
except TypeError: except TypeError:
self.chunks.append('#!') raise TypeError('Cannot preserves-format: ' + repr(v))
self.append(self.format_embedded(v))
return
self.write_seq('[', ']', i) self.write_seq('[', ']', i)
def stringify(v, **kwargs): def stringify(v, **kwargs):

View File

@ -4,6 +4,11 @@ import struct
from .error import DecodeError from .error import DecodeError
def preserve(v):
while hasattr(v, '__preserve__'):
v = v.__preserve__()
return v
class Float(object): class Float(object):
def __init__(self, value): def __init__(self, value):
self.value = value self.value = value
@ -276,3 +281,23 @@ def _unwrap(x):
return x.item return x.item
else: else:
return x return x
class Embedded:
def __init__(self, value):
self.embeddedValue = value
def __eq__(self, other):
other = _unwrap(other)
if other.__class__ is self.__class__:
return self.embeddedValue == other.embeddedValue
def __hash__(self):
return hash(self.embeddedValue)
def __preserve_write_binary__(self, encoder):
encoder.buffer.append(0x86)
encoder.append(encoder.encode_embedded(self.embeddedValue))
def __preserve_write_text__(self, formatter):
formatter.chunks.append('#!')
formatter.append(formatter.format_embedded(self.embeddedValue))