From 21de8b799a1695f77b1a2e7b7fe0448cbf4c8e2c Mon Sep 17 00:00:00 2001 From: Tony Garnock-Jones Date: Wed, 9 Feb 2022 14:59:50 +0100 Subject: [PATCH] Canonicalizing binary encoder --- implementations/python/preserves/__init__.py | 2 +- implementations/python/preserves/binary.py | 47 +++++++++++++++++-- .../python/tests/test_preserves.py | 9 +++- 3 files changed, 52 insertions(+), 6 deletions(-) diff --git a/implementations/python/preserves/__init__.py b/implementations/python/preserves/__init__.py index 393e2b9..7a163ef 100644 --- a/implementations/python/preserves/__init__.py +++ b/implementations/python/preserves/__init__.py @@ -3,7 +3,7 @@ from .values import Annotated, is_annotated, strip_annotations, annotate from .error import DecodeError, EncodeError, ShortPacket -from .binary import Decoder, Encoder, decode, decode_with_annotations, encode +from .binary import Decoder, Encoder, decode, decode_with_annotations, encode, canonicalize from .text import Parser, Formatter, parse, parse_with_annotations, stringify from . import fold diff --git a/implementations/python/preserves/binary.py b/implementations/python/preserves/binary.py index 25745ad..78ec1dd 100644 --- a/implementations/python/preserves/binary.py +++ b/implementations/python/preserves/binary.py @@ -122,10 +122,14 @@ def decode_with_annotations(bs, **kwargs): return Decoder(packet=bs, include_annotations=True, **kwargs).next() class Encoder(BinaryCodec): - def __init__(self, encode_embedded=lambda x: x): + def __init__(self, encode_embedded=lambda x: x, canonicalize=False): super(Encoder, self).__init__() self.buffer = bytearray() self._encode_embedded = encode_embedded + self._canonicalize = canonicalize + + def reset(self): + self.buffer = bytearray() def encode_embedded(self, v): if self._encode_embedded is None: @@ -166,6 +170,22 @@ class Encoder(BinaryCodec): self.varint(len(bs)) self.buffer.extend(bs) + def encodeset(self, v): + if not self._canonicalize: + self.encodevalues(6, v) + else: + c = Canonicalizer(self._encode_embedded) + for i in v: c.entry([i]) + c.emit_entries(self, 6) + + def encodedict(self, v): + if not self._canonicalize: + self.encodevalues(7, list(dict_kvs(v))) + else: + c = Canonicalizer(self._encode_embedded) + for (kk, vv) in v.items(): c.entry([kk, vv]) + c.emit_entries(self, 7) + def append(self, v): v = preserve(v) if hasattr(v, '__preserve_write_binary__'): @@ -191,11 +211,11 @@ class Encoder(BinaryCodec): elif isinstance(v, tuple): self.encodevalues(5, v) elif isinstance(v, set): - self.encodevalues(6, v) + self.encodeset(v) elif isinstance(v, frozenset): - self.encodevalues(6, v) + self.encodeset(v) elif isinstance(v, dict): - self.encodevalues(7, list(dict_kvs(v))) + self.encodedict(v) else: try: i = iter(v) @@ -209,7 +229,26 @@ class Encoder(BinaryCodec): def cannot_encode(self, v): raise TypeError('Cannot preserves-encode: ' + repr(v)) +class Canonicalizer: + def __init__(self, encode_embedded): + self.encoder = Encoder(encode_embedded, canonicalize=True) + self.entries = [] + + def entry(self, pieces): + for piece in pieces: self.encoder.append(piece) + entry = self.encoder.contents() + self.encoder.reset() + self.entries.append(entry) + + def emit_entries(self, outer_encoder, tag): + outer_encoder.buffer.append(0xb0 + tag) + for e in sorted(self.entries): outer_encoder.buffer.extend(e) + outer_encoder.buffer.append(0x84) + def encode(v, **kwargs): e = Encoder(**kwargs) e.append(v) return e.contents() + +def canonicalize(v, **kwargs): + return encode(v, canonicalize=True, **kwargs) diff --git a/implementations/python/tests/test_preserves.py b/implementations/python/tests/test_preserves.py index 98b13a1..9dc2f9c 100644 --- a/implementations/python/tests/test_preserves.py +++ b/implementations/python/tests/test_preserves.py @@ -260,14 +260,18 @@ def install_test(d, variant, tName, binaryForm, annotatedTextForm): def test_back(self): self.assertEqual(self.DS(binaryForm), back) def test_back_ann(self): self.assertEqual(self.D(self.E(annotatedTextForm)), annotatedTextForm) def test_encode(self): self.assertEqual(self.E(forward), binaryForm) + def test_encode_canonical(self): self.assertEqual(self.EC(annotatedTextForm), binaryForm) def test_encode_ann(self): self.assertEqual(self.E(annotatedTextForm), binaryForm) add_method(d, tName, test_match_expected) add_method(d, tName, test_roundtrip) add_method(d, tName, test_forward) add_method(d, tName, test_back) add_method(d, tName, test_back_ann) - if variant not in ['decode', 'nondeterministic']: + if variant in ['normal']: add_method(d, tName, test_encode) + if variant in ['nondeterministic']: + add_method(d, tName, test_encode_canonical) + if variant in ['normal', 'nondeterministic']: add_method(d, tName, test_encode_ann) def install_exn_test(d, tName, bs, check_proc): @@ -318,6 +322,9 @@ class CommonTestSuite(unittest.TestCase): def E(self, v): return encode(v, encode_embedded=lambda x: x) + def EC(self, v): + return encode(v, encode_embedded=lambda x: x, canonicalize=True) + class RecordTests(unittest.TestCase): def test_getters(self): T = Record.makeConstructor('t', 'x y z')