Canonicalizing binary encoder

This commit is contained in:
Tony Garnock-Jones 2022-02-09 14:59:50 +01:00
parent 0596d877f8
commit 21de8b799a
3 changed files with 52 additions and 6 deletions

View File

@ -3,7 +3,7 @@ from .values import Annotated, is_annotated, strip_annotations, annotate
from .error import DecodeError, EncodeError, ShortPacket
from .binary import Decoder, Encoder, decode, decode_with_annotations, encode
from .binary import Decoder, Encoder, decode, decode_with_annotations, encode, canonicalize
from .text import Parser, Formatter, parse, parse_with_annotations, stringify
from . import fold

View File

@ -122,10 +122,14 @@ def decode_with_annotations(bs, **kwargs):
return Decoder(packet=bs, include_annotations=True, **kwargs).next()
class Encoder(BinaryCodec):
def __init__(self, encode_embedded=lambda x: x):
def __init__(self, encode_embedded=lambda x: x, canonicalize=False):
    """Set up an encoder with an empty output buffer.

    encode_embedded -- callable stored for later use when embedded values
        are encoded; the default returns its argument unchanged.
    canonicalize -- when true, encodeset/encodedict route their entries
        through a Canonicalizer, which emits them sorted by encoded bytes,
        instead of writing them in iteration order.
    """
    super(Encoder, self).__init__()
    self._canonicalize = canonicalize
    self._encode_embedded = encode_embedded
    self.buffer = bytearray()
def reset(self):
    """Drop everything encoded so far by installing a fresh, empty buffer.

    The previous buffer object is replaced rather than cleared in place.
    """
    self.buffer = bytearray()
def encode_embedded(self, v):
if self._encode_embedded is None:
@ -166,6 +170,22 @@ class Encoder(BinaryCodec):
self.varint(len(bs))
self.buffer.extend(bs)
def encodeset(self, v):
    """Write the set `v` to this encoder as a tag-6 compound.

    In canonicalizing mode each member is encoded on its own by a
    Canonicalizer, which then emits the encodings in sorted order.
    Otherwise the members are handed to encodevalues in iteration order.
    """
    if self._canonicalize:
        canon = Canonicalizer(self._encode_embedded)
        for member in v:
            # Each set member forms a one-piece entry.
            canon.entry([member])
        canon.emit_entries(self, 6)
        return
    self.encodevalues(6, v)
def encodedict(self, v):
    """Write the dictionary `v` to this encoder as a tag-7 compound.

    In canonicalizing mode each key/value pair is encoded as one entry
    (key first, then value) by a Canonicalizer, which then emits the
    entries in sorted order. Otherwise the output of dict_kvs(v) is
    handed to encodevalues as a list.
    """
    if self._canonicalize:
        canon = Canonicalizer(self._encode_embedded)
        for key, value in v.items():
            # Key and value are encoded back to back into a single entry.
            canon.entry([key, value])
        canon.emit_entries(self, 7)
        return
    self.encodevalues(7, list(dict_kvs(v)))
def append(self, v):
v = preserve(v)
if hasattr(v, '__preserve_write_binary__'):
@ -191,11 +211,11 @@ class Encoder(BinaryCodec):
elif isinstance(v, tuple):
self.encodevalues(5, v)
elif isinstance(v, set):
self.encodevalues(6, v)
self.encodeset(v)
elif isinstance(v, frozenset):
self.encodevalues(6, v)
self.encodeset(v)
elif isinstance(v, dict):
self.encodevalues(7, list(dict_kvs(v)))
self.encodedict(v)
else:
try:
i = iter(v)
@ -209,7 +229,26 @@ class Encoder(BinaryCodec):
def cannot_encode(self, v):
    """Signal that `v` has no encoding.

    Always raises TypeError; the message includes repr(v).
    """
    message = 'Cannot preserves-encode: ' + repr(v)
    raise TypeError(message)
class Canonicalizer:
    """Collects separately-encoded entries and emits them in sorted order.

    Each entry is encoded with a private canonicalizing Encoder, so any
    sets or dictionaries nested inside an entry are canonicalized too.
    """

    def __init__(self, encode_embedded):
        """Create an empty collector.

        encode_embedded -- forwarded to the private Encoder.
        """
        self.entries = []
        self.encoder = Encoder(encode_embedded, canonicalize=True)

    def entry(self, pieces):
        """Encode `pieces` back to back and record the result as one entry.

        The private encoder is reset afterwards so the next entry starts
        from an empty buffer.
        """
        scratch = self.encoder
        for piece in pieces:
            scratch.append(piece)
        encoded = scratch.contents()
        scratch.reset()
        self.entries.append(encoded)

    def emit_entries(self, outer_encoder, tag):
        """Append the recorded entries to `outer_encoder.buffer`.

        Writes the byte 0xb0 + tag, then every entry in sorted order,
        then the byte 0x84 (presumably the format's end marker -- confirm
        against the binary syntax specification). Callers in this file
        pass tag 6 for sets and 7 for dictionaries.
        """
        out = outer_encoder.buffer
        out.append(0xb0 + tag)
        for encoded in sorted(self.entries):
            out.extend(encoded)
        out.append(0x84)
def encode(v, **kwargs):
    """Encode the single value `v` and return the encoder's contents.

    All keyword arguments are passed through to the Encoder constructor.
    """
    encoder = Encoder(**kwargs)
    encoder.append(v)
    return encoder.contents()
def canonicalize(v, **kwargs):
    """Encode `v` with canonicalization switched on.

    Equivalent to encode(v, canonicalize=True, **kwargs); the remaining
    keyword arguments are passed through unchanged.
    """
    canonical_form = encode(v, canonicalize=True, **kwargs)
    return canonical_form

View File

@ -260,14 +260,18 @@ def install_test(d, variant, tName, binaryForm, annotatedTextForm):
def test_back(self): self.assertEqual(self.DS(binaryForm), back)
def test_back_ann(self): self.assertEqual(self.D(self.E(annotatedTextForm)), annotatedTextForm)
def test_encode(self): self.assertEqual(self.E(forward), binaryForm)
def test_encode_canonical(self): self.assertEqual(self.EC(annotatedTextForm), binaryForm)
def test_encode_ann(self): self.assertEqual(self.E(annotatedTextForm), binaryForm)
add_method(d, tName, test_match_expected)
add_method(d, tName, test_roundtrip)
add_method(d, tName, test_forward)
add_method(d, tName, test_back)
add_method(d, tName, test_back_ann)
if variant not in ['decode', 'nondeterministic']:
if variant in ['normal']:
add_method(d, tName, test_encode)
if variant in ['nondeterministic']:
add_method(d, tName, test_encode_canonical)
if variant in ['normal', 'nondeterministic']:
add_method(d, tName, test_encode_ann)
def install_exn_test(d, tName, bs, check_proc):
@ -318,6 +322,9 @@ class CommonTestSuite(unittest.TestCase):
def E(self, v):
    """Encode `v`, passing embedded values through unchanged."""
    def passthrough(embedded):
        return embedded
    return encode(v, encode_embedded=passthrough)
def EC(self, v):
    """Encode `v` with canonicalization on, passing embedded values through unchanged."""
    def passthrough(embedded):
        return embedded
    return encode(v, encode_embedded=passthrough, canonicalize=True)
class RecordTests(unittest.TestCase):
def test_getters(self):
T = Record.makeConstructor('t', 'x y z')