diff --git a/implementations/python/preserves/__init__.py b/implementations/python/preserves/__init__.py index 506d8ad..9a1f59b 100644 --- a/implementations/python/preserves/__init__.py +++ b/implementations/python/preserves/__init__.py @@ -1,4 +1,4 @@ -from .values import Float, Symbol, Record, ImmutableDict +from .values import Float, Symbol, Record, ImmutableDict, Embedded, preserve from .values import Annotated, is_annotated, strip_annotations, annotate from .error import DecodeError, EncodeError, ShortPacket diff --git a/implementations/python/preserves/binary.py b/implementations/python/preserves/binary.py index 22721fe..5704b4c 100644 --- a/implementations/python/preserves/binary.py +++ b/implementations/python/preserves/binary.py @@ -8,7 +8,7 @@ from .compat import basestring_, ord_ class BinaryCodec(object): pass class Decoder(BinaryCodec): - def __init__(self, packet=b'', include_annotations=False, decode_embedded=None): + def __init__(self, packet=b'', include_annotations=False, decode_embedded=lambda x: x): super(Decoder, self).__init__() self.packet = packet self.index = 0 @@ -82,7 +82,7 @@ class Decoder(BinaryCodec): if tag == 0x86: if self.decode_embedded is None: raise DecodeError('No decode_embedded function supplied') - return self.wrap(self.decode_embedded(self.next())) + return self.wrap(Embedded(self.decode_embedded(self.next()))) if tag >= 0x90 and tag <= 0x9f: return self.wrap(tag - (0xa0 if tag > 0x9c else 0x90)) if tag >= 0xa0 and tag <= 0xaf: return self.wrap(self.nextint(tag - 0xa0 + 1)) if tag == 0xb0: return self.wrap(self.nextint(self.varint())) @@ -106,6 +106,15 @@ class Decoder(BinaryCodec): self.index = start return None + def __iter__(self): + return self + + def __next__(self): + v = self.try_next() + if v is None: + raise StopIteration + return v + def decode(bs, **kwargs): return Decoder(packet=bs, **kwargs).next() @@ -113,10 +122,15 @@ def decode_with_annotations(bs, **kwargs): return Decoder(packet=bs, include_annotations=True, **kwargs).next() class Encoder(BinaryCodec): - def __init__(self, encode_embedded=id): + def __init__(self, encode_embedded=lambda x: x): super(Encoder, self).__init__() self.buffer = bytearray() - self.encode_embedded = encode_embedded + self._encode_embedded = encode_embedded + + def encode_embedded(self, v): + if self._encode_embedded is None: + raise EncodeError('No encode_embedded function supplied') + return self._encode_embedded(v) def contents(self): return bytes(self.buffer) @@ -153,9 +167,7 @@ class Encoder(BinaryCodec): self.buffer.extend(bs) def append(self, v): - while hasattr(v, '__preserve__'): - v = v.__preserve__() - + v = preserve(v) if hasattr(v, '__preserve_write_binary__'): v.__preserve_write_binary__(self) elif v is False: @@ -188,9 +200,7 @@ class Encoder(BinaryCodec): try: i = iter(v) except TypeError: - self.buffer.append(0x86) - self.append(self.encode_embedded(v)) - return + raise TypeError('Cannot preserves-encode: ' + repr(v)) self.encodevalues(5, i) def encode(v, **kwargs): diff --git a/implementations/python/preserves/schema.py b/implementations/python/preserves/schema.py index 34e6882..965b742 100644 --- a/implementations/python/preserves/schema.py +++ b/implementations/python/preserves/schema.py @@ -1,4 +1,4 @@ -from .preserves import * +from . import * import pathlib import keyword @@ -66,7 +66,8 @@ class SchemaObject: if k == SYMBOL and isinstance(v, Symbol): return v return None if p.key == EMBEDDED: - return v ## TODO: reconsider representation of embedded values? + if not isinstance(v, Embedded): return None + return v.embeddedValue if p.key == LIT: if v == p[0]: return () return None @@ -283,7 +284,7 @@ def encode(p, v): if p.key == ATOM: return v if p.key == EMBEDDED: - return v ## TODO: reconsider representation of embedded values? + return Embedded(v) if p.key == LIT: return p[0] if p.key == SEQOF: @@ -417,7 +418,10 @@ def load_schema_file(filename): # a decorator def extend(cls): - return lambda f: setattr(cls, f.__name__, f) + def extender(f): + setattr(cls, f.__name__, f) + return f + return extender __metaschema_filename = pathlib.Path(__file__).parent / '../../../schema/schema.bin' meta = load_schema_file(__metaschema_filename).schema diff --git a/implementations/python/preserves/test_preserves.py b/implementations/python/preserves/test_preserves.py index 0b95e90..5c82ecd 100644 --- a/implementations/python/preserves/test_preserves.py +++ b/implementations/python/preserves/test_preserves.py @@ -166,23 +166,23 @@ class BinaryCodecTests(unittest.TestCase): class A: def __init__(self, a): self.a = a - a1 = A(1) - a2 = A(1) - self.assertNotEqual(_e(a1), _e(a2)) - self.assertEqual(_e(a1), _e(a1)) - self.assertEqual(ord_(_e(a1)[0]), 0x86) - self.assertEqual(ord_(_e(a2)[0]), 0x86) + a1 = Embedded(A(1)) + a2 = Embedded(A(1)) + self.assertNotEqual(encode(a1, encode_embedded=id), encode(a2, encode_embedded=id)) + self.assertEqual(encode(a1, encode_embedded=id), encode(a1, encode_embedded=id)) + self.assertEqual(ord_(encode(a1, encode_embedded=id)[0]), 0x86) + self.assertEqual(ord_(encode(a2, encode_embedded=id)[0]), 0x86) def test_decode_embedded_absent(self): with self.assertRaises(DecodeError): - decode(b'\x86\xa0\xff') + decode(b'\x86\xa0\xff', decode_embedded=None) def test_encode_embedded(self): objects = [] def enc(p): objects.append(p) return len(objects) - 1 - self.assertEqual(encode([object(), object()], encode_embedded = enc), + self.assertEqual(encode([Embedded(object()), Embedded(object())], encode_embedded = enc), b'\xb5\x86\x90\x86\x91\x84') def test_decode_embedded(self): @@ -190,15 +190,15 @@ class BinaryCodecTests(unittest.TestCase): def dec(v): return objects[v] self.assertEqual(decode(b'\xb5\x86\x90\x86\x91\x84', decode_embedded = dec), - (123, 234)) + (Embedded(123), Embedded(234))) def load_binary_samples(): with open(os.path.join(os.path.dirname(__file__), '../../../tests/samples.bin'), 'rb') as f: - return Decoder(f.read(), include_annotations=True, decode_embedded=Embedded).next() + return Decoder(f.read(), include_annotations=True, decode_embedded=lambda x: x).next() def load_text_samples(): with open(os.path.join(os.path.dirname(__file__), '../../../tests/samples.pr'), 'rt') as f: - return Parser(f.read(), include_annotations=True, parse_embedded=Embedded).next() + return Parser(f.read(), include_annotations=True, parse_embedded=lambda x: x).next() class TextCodecTests(unittest.TestCase): def test_samples_bin_eq_txt(self): @@ -208,8 +208,8 @@ class TextCodecTests(unittest.TestCase): def test_txt_roundtrip(self): b = load_binary_samples() - s = stringify(b, format_embedded=Embedded.value) - self.assertEqual(parse(s, include_annotations=True, parse_embedded=Embedded), b) + s = stringify(b, format_embedded=lambda x: x) + self.assertEqual(parse(s, include_annotations=True, parse_embedded=lambda x: x), b) def add_method(d, tName, fn): if hasattr(fn, 'func_name'): @@ -277,22 +277,6 @@ def install_exn_test(d, tName, bs, check_proc): self.fail('did not fail as expected') add_method(d, tName, test_exn) -class Embedded: - def __init__(self, v): - self.v = strip_annotations(v) - - @staticmethod - def value(i): - return i.v - - def __eq__(self, other): - other = _unwrap(other) - if other.__class__ is self.__class__: - return self.v == other.v - - def __hash__(self): - return hash(self.v) - class CommonTestSuite(unittest.TestCase): TestCases = Record.makeConstructor('TestCases', 'cases') @@ -323,13 +307,13 @@ class CommonTestSuite(unittest.TestCase): raise Exception('Unsupported test kind', t.key) def DS(self, bs): - return decode(bs, decode_embedded=Embedded) + return decode(bs, decode_embedded=lambda x: x) def D(self, bs): - return decode_with_annotations(bs, decode_embedded=Embedded) + return decode_with_annotations(bs, decode_embedded=lambda x: x) def E(self, v): - return encode(v, encode_embedded=Embedded.value) + return encode(v, encode_embedded=lambda x: x) class RecordTests(unittest.TestCase): def test_getters(self): diff --git a/implementations/python/preserves/text.py b/implementations/python/preserves/text.py index 935fbbb..c36a739 100644 --- a/implementations/python/preserves/text.py +++ b/implementations/python/preserves/text.py @@ -10,7 +10,7 @@ from .binary import Decoder class TextCodec(object): pass class Parser(TextCodec): - def __init__(self, input_buffer=u'', include_annotations=False, parse_embedded=None): + def __init__(self, input_buffer=u'', include_annotations=False, parse_embedded=lambda x: x): super(Parser, self).__init__() self.input_buffer = input_buffer self.index = 0 @@ -267,7 +267,9 @@ class Parser(TextCodec): raise DecodeError('ByteString must follow #=') return self.wrap(Decoder(bs_val, include_annotations = self.include_annotations).next()) if c == '!': - return self.wrap(self.parse_embedded(self.next())) + if self.parse_embedded is None: + raise DecodeError('No parse_embedded function supplied') + return self.wrap(Embedded(self.parse_embedded(self.next()))) raise DecodeError('Invalid # syntax') if c == '<': self.skip() @@ -294,6 +296,15 @@ class Parser(TextCodec): self.index = start return None + def __iter__(self): + return self + + def __next__(self): + v = self.try_next() + if v is None: + raise StopIteration + return v + def parse(bs, **kwargs): return Parser(input_buffer=bs, **kwargs).next() @@ -301,10 +312,15 @@ def parse_with_annotations(bs, **kwargs): return Parser(input_buffer=bs, include_annotations=True, **kwargs).next() class Formatter(TextCodec): - def __init__(self, format_embedded=None): + def __init__(self, format_embedded=lambda x: x): super(Formatter, self).__init__() self.chunks = [] - self.format_embedded = format_embedded + self._format_embedded = format_embedded + + def format_embedded(self, v): + if self._format_embedded is None: + raise EncodeError('No format_embedded function supplied') + return self._format_embedded(v) def contents(self): return u''.join(self.chunks) @@ -330,9 +346,7 @@ class Formatter(TextCodec): self.chunks.append(closer) def append(self, v): - while hasattr(v, '__preserve__'): - v = v.__preserve__() - + v = preserve(v) if hasattr(v, '__preserve_write_text__'): v.__preserve_write_text__(self) elif v is False: @@ -375,9 +389,7 @@ class Formatter(TextCodec): try: i = iter(v) except TypeError: - self.chunks.append('#!') - self.append(self.format_embedded(v)) - return + raise TypeError('Cannot preserves-format: ' + repr(v)) self.write_seq('[', ']', i) def stringify(v, **kwargs): diff --git a/implementations/python/preserves/values.py b/implementations/python/preserves/values.py index c617bf1..e51f394 100644 --- a/implementations/python/preserves/values.py +++ b/implementations/python/preserves/values.py @@ -4,6 +4,11 @@ import struct from .error import DecodeError +def preserve(v): + while hasattr(v, '__preserve__'): + v = v.__preserve__() + return v + class Float(object): def __init__(self, value): self.value = value @@ -276,3 +281,23 @@ def _unwrap(x): return x.item else: return x + +class Embedded: + def __init__(self, value): + self.embeddedValue = value + + def __eq__(self, other): + other = _unwrap(other) + if other.__class__ is self.__class__: + return self.embeddedValue == other.embeddedValue + + def __hash__(self): + return hash(self.embeddedValue) + + def __preserve_write_binary__(self, encoder): + encoder.buffer.append(0x86) + encoder.append(encoder.encode_embedded(self.embeddedValue)) + + def __preserve_write_text__(self, formatter): + formatter.chunks.append('#!') + formatter.append(formatter.format_embedded(self.embeddedValue))