Python text codec

This commit is contained in:
Tony Garnock-Jones 2021-08-17 08:04:38 -04:00
parent 123b6222ca
commit 8afc8f1eae
7 changed files with 476 additions and 27 deletions

View File

@ -1,6 +1,7 @@
from .repr import Float, Symbol, Record, ImmutableDict
from .repr import Annotated, is_annotated, strip_annotations, annotate
from .values import Float, Symbol, Record, ImmutableDict
from .values import Annotated, is_annotated, strip_annotations, annotate
from .error import DecodeError, EncodeError, ShortPacket
from .binary import Decoder, Encoder, decode, decode_with_annotations, encode
from .text import Parser, Formatter, parse, parse_with_annotations, stringify

View File

@ -1,13 +1,13 @@
import numbers
import struct
from .repr import *
from .values import *
from .error import *
from .compat import basestring_, ord_
class Codec(object): pass
class BinaryCodec(object): pass
class Decoder(Codec):
class Decoder(BinaryCodec):
def __init__(self, packet=b'', include_annotations=False, decode_embedded=None):
super(Decoder, self).__init__()
self.packet = packet
@ -112,7 +112,7 @@ def decode(bs, **kwargs):
def decode_with_annotations(bs, **kwargs):
return Decoder(packet=bs, include_annotations=True, **kwargs).next()
class Encoder(Codec):
class Encoder(BinaryCodec):
def __init__(self, encode_embedded=id):
super(Encoder, self).__init__()
self.buffer = bytearray()
@ -153,8 +153,11 @@ class Encoder(Codec):
self.buffer.extend(bs)
def append(self, v):
if hasattr(v, '__preserve_on__'):
v.__preserve_on__(self)
while hasattr(v, '__preserve__'):
v = v.__preserve__()
if hasattr(v, '__preserve_write_binary__'):
v.__preserve_write_binary__(self)
elif v is False:
self.buffer.append(0x80)
elif v is True:

View File

@ -7,3 +7,8 @@ if isinstance(chr(123), bytes):
ord_ = ord
else:
ord_ = lambda x: x
try:
unichr_ = unichr
except NameError:
unichr_ = chr

View File

@ -130,7 +130,7 @@ class SchemaObject:
return ()
raise ValueError('Bad schema')
def _encode(self):
def __preserve__(self):
raise NotImplementedError('Subclass responsibility')
def __repr__(self):
@ -173,7 +173,7 @@ class Enumeration(SchemaObject):
if i is not None: return i
return None
def _encode(self):
def __preserve__(self):
raise TypeError('Cannot encode instance of Enumeration')
def safeattrname(k):
@ -257,7 +257,7 @@ class Definition(SchemaObject):
if cls.parse(cls.SCHEMA, v, args) is not None: return cls(*args)
return None
def _encode(self):
def __preserve__(self):
if self.SIMPLE:
if self.EMPTY:
return encode(self.SCHEMA, ())
@ -293,7 +293,7 @@ def encode(p, v):
if p.key == DICTOF:
return dict((encode(p[0], k), encode(p[1], w)) for (k, w) in v.items())
if p.key == REF:
return v._encode()
return v.__preserve__()
if p.key == REC:
return Record(encode(p[0], v), encode(p[1], v))
if p.key == TUPLE:
@ -426,8 +426,8 @@ if __name__ == '__main__':
with open(__metaschema_filename, 'rb') as f:
x = Decoder(f.read()).next()
print(meta.Schema.decode(x))
print(meta.Schema.decode(x)._encode())
assert meta.Schema.decode(x)._encode() == x
print(meta.Schema.decode(x).__preserve__())
assert meta.Schema.decode(x).__preserve__() == x
@extend(meta.Schema)
def f(self, x):
@ -443,7 +443,7 @@ if __name__ == '__main__':
x = Decoder(f.read()).next()
print(meta.Schema.decode(x))
assert meta.Schema.decode(x) == meta.Schema.decode(x)
assert meta.Schema.decode(x)._encode() == x
assert meta.Schema.decode(x).__preserve__() == x
print()
print(path)

View File

@ -1,9 +1,11 @@
import numbers
import os
import sys
import unittest
from . import *
from .compat import basestring_, ord_
from .values import _unwrap
if isinstance(chr(123), bytes):
def _byte(x):
@ -44,7 +46,7 @@ def _e(v):
def _R(k, *args):
return Record(Symbol(k), args)
class CodecTests(unittest.TestCase):
class BinaryCodecTests(unittest.TestCase):
def _roundtrip(self, forward, expected, back=None, nondeterministic=False):
if back is None: back = forward
self.assertEqual(_d(_e(forward)), back)
@ -190,6 +192,25 @@ class CodecTests(unittest.TestCase):
self.assertEqual(decode(b'\xb5\x86\x90\x86\x91\x84', decode_embedded = dec),
(123, 234))
def load_binary_samples():
    """Decode the shared binary sample file, keeping annotations.

    Embedded values are decoded with `Embedded`.
    """
    samples_path = os.path.join(os.path.dirname(__file__), '../../../tests/samples.bin')
    with open(samples_path, 'rb') as f:
        packet = f.read()
    decoder = Decoder(packet, include_annotations=True, decode_embedded=Embedded)
    return decoder.next()
def load_text_samples():
    """Parse the shared text sample file, keeping annotations.

    Embedded values are parsed with `Embedded`.
    """
    # io.open accepts `encoding` on both Python 2 and 3 and always yields
    # unicode text. The builtin open() would decode using the platform's
    # default encoding, which is not UTF-8 everywhere.
    # NOTE(review): assumes samples.pr is UTF-8 encoded -- confirm.
    import io
    samples_path = os.path.join(os.path.dirname(__file__), '../../../tests/samples.pr')
    with io.open(samples_path, 'rt', encoding='utf-8') as f:
        return Parser(f.read(), include_annotations=True, parse_embedded=Embedded).next()
class TextCodecTests(unittest.TestCase):
    """Checks the text codec against the binary codec using the shared samples."""

    def test_samples_bin_eq_txt(self):
        # Both sample files must decode to equal values.
        from_binary = load_binary_samples()
        from_text = load_text_samples()
        self.assertEqual(from_binary, from_text)

    def test_txt_roundtrip(self):
        # Formatting a value as text and parsing it back must give an equal value.
        original = load_binary_samples()
        text = stringify(original, format_embedded=Embedded.value)
        reparsed = parse(text, include_annotations=True, parse_embedded=Embedded)
        self.assertEqual(reparsed, original)
def add_method(d, tName, fn):
if hasattr(fn, 'func_name'):
# python2
@ -265,6 +286,7 @@ class Embedded:
return i.v
def __eq__(self, other):
other = _unwrap(other)
if other.__class__ is self.__class__:
return self.v == other.v
@ -272,13 +294,9 @@ class Embedded:
return hash(self.v)
class CommonTestSuite(unittest.TestCase):
import os
with open(os.path.join(os.path.dirname(__file__),
'../../../tests/samples.bin'), 'rb') as f:
samples = Decoder(f.read(), include_annotations=True, decode_embedded=Embedded).next()
TestCases = Record.makeConstructor('TestCases', 'cases')
samples = load_binary_samples()
tests = TestCases._cases(samples.peel()).peel()
for (tName0, t0) in tests.items():
tName = tName0.strip().name

View File

@ -0,0 +1,386 @@
import numbers
import struct
import base64
from .values import *
from .error import *
from .compat import basestring_, unichr_
from .binary import Decoder
class TextCodec(object):
    """Common base class of the text-syntax Parser and Formatter."""
    pass


class Parser(TextCodec):
    """Incremental parser for the textual syntax.

    Input is held in `input_buffer` and consumed from position `index`.
    More text can be supplied with `extend`. Running out of input in the
    middle of a value raises ShortPacket; malformed input raises DecodeError.

    If `include_annotations` is true, every parsed value is wrapped in
    Annotated. `parse_embedded` is called on the value following `#!`.
    """

    def __init__(self, input_buffer=u'', include_annotations=False, parse_embedded=None):
        super(Parser, self).__init__()
        self.input_buffer = input_buffer
        self.index = 0
        self.include_annotations = include_annotations
        self.parse_embedded = parse_embedded

    def extend(self, text):
        """Drop already-consumed input and append `text` to what remains."""
        self.input_buffer = self.input_buffer[self.index:] + text
        self.index = 0

    def _atend(self):
        """True iff every character of the buffer has been consumed."""
        return self.index >= len(self.input_buffer)

    def _peek_in(self, chars):
        """True iff input remains and the next character is one of `chars`.

        Unlike `peek`, this never raises at end of input. It is used where
        end of input is an acceptable terminator.
        """
        return not self._atend() and self.peek() in chars

    def peek(self):
        """Return the next character without consuming it.

        Raises ShortPacket at end of input.
        """
        if self._atend():
            raise ShortPacket('Short input buffer')
        return self.input_buffer[self.index]

    def skip(self):
        """Consume one character."""
        self.index = self.index + 1

    def nextchar(self):
        """Consume and return the next character (ShortPacket at end of input)."""
        c = self.peek()
        self.skip()
        return c

    def skip_whitespace(self):
        """Consume whitespace and commas; both are treated as separators."""
        while not self._atend():
            c = self.peek()
            if not (c.isspace() or c == ','):
                break
            self.skip()

    def gather_annotations(self):
        """Collect the run of `;` comments and `@` annotations at the cursor.

        Returns the list of annotation values, in order. Raises ShortPacket
        if the input ends before a non-annotation character is seen.
        """
        vs = []
        while True:
            self.skip_whitespace()
            c = self.peek()
            if c == ';':
                self.skip()
                vs.append(self.comment_line())
            elif c == '@':
                self.skip()
                vs.append(self.next())
            else:
                return vs

    def comment_line(self):
        """Read up to and including the next CR or LF; return the text before it."""
        s = []
        while True:
            c = self.nextchar()
            if c == '\r' or c == '\n':
                return self.wrap(u''.join(s))
            s.append(c)

    def read_intpart(self, acc, c):
        """Read the integer part of a number whose first digit is `c`.

        `acc` holds the characters read so far (e.g. a leading '-').
        A leading '0' is taken as the entire integer part.
        """
        if c == '0':
            acc.append(c)
        else:
            self.read_digit1(acc, c)
        return self.read_fracexp(acc)

    def read_fracexp(self, acc):
        """Read the optional fraction and exponent following the integer part."""
        is_float = False
        # End of input is a valid terminator for a number, exactly as it is
        # in read_digit1, so do not demand a lookahead character here.
        if self._peek_in('.'):
            is_float = True
            acc.append(self.nextchar())
            self.read_digit1(acc, self.nextchar())
        if self._peek_in('eE'):
            acc.append(self.nextchar())
            return self.read_sign_and_exp(acc)
        else:
            return self.finish_number(acc, is_float)

    def read_sign_and_exp(self, acc):
        """Read the optional sign and the mandatory digits of an exponent."""
        # Digits must follow, so running out of input here really is short input.
        if self.peek() in '+-':
            acc.append(self.nextchar())
        self.read_digit1(acc, self.nextchar())
        return self.finish_number(acc, True)

    def finish_number(self, acc, is_float):
        """Convert the accumulated characters into an int, float or Float.

        A floating-point literal with an 'f'/'F' suffix becomes a Float.
        """
        text = u''.join(acc)
        if not is_float:
            return int(text)
        if self._peek_in('fF'):
            self.skip()
            return Float(float(text))
        return float(text)

    def read_digit1(self, acc, c):
        """Append `c`, which must be a digit, and any digits that follow it."""
        if not c.isdigit():
            raise DecodeError('Incomplete number')
        acc.append(c)
        while not self._atend():
            if not self.peek().isdigit():
                break
            acc.append(self.nextchar())

    def read_stringlike(self, terminator, hexescape, hexescaper):
        """Read characters up to `terminator`, processing backslash escapes.

        When the escape character equals `hexescape`, `hexescaper(acc)` is
        called to read the escape's payload and append the result to `acc`.
        """
        acc = []
        while True:
            c = self.nextchar()
            if c == terminator:
                return u''.join(acc)
            if c == '\\':
                c = self.nextchar()
                if c == hexescape: hexescaper(acc)
                elif c == terminator or c == '\\' or c == '/': acc.append(c)
                elif c == 'b': acc.append(u'\x08')
                elif c == 'f': acc.append(u'\x0c')
                elif c == 'n': acc.append(u'\x0a')
                elif c == 'r': acc.append(u'\x0d')
                elif c == 't': acc.append(u'\x09')
                else: raise DecodeError('Invalid escape code')
            else:
                acc.append(c)

    def hexnum(self, count):
        """Read exactly `count` hex digits (either case) and return their value."""
        v = 0
        for _ in range(count):
            c = self.nextchar().lower()
            if '0' <= c <= '9':
                v = v << 4 | (ord(c) - ord('0'))
            elif 'a' <= c <= 'f':
                v = v << 4 | (ord(c) - ord('a') + 10)
            else:
                raise DecodeError('Bad hex escape')
        return v

    def read_string(self, delimiter):
        """Read a string body ending at `delimiter`, with `\\uXXXX` escapes.

        A high surrogate must be followed immediately by a `\\u` low
        surrogate; the pair is combined into one code point.
        """
        def u16_escape(acc):
            n1 = self.hexnum(4)
            if n1 >= 0xd800 and n1 <= 0xdbff:
                ok = True
                ok = ok and self.nextchar() == '\\'
                ok = ok and self.nextchar() == 'u'
                if not ok:
                    raise DecodeError('Missing second half of surrogate pair')
                n2 = self.hexnum(4)
                if n2 >= 0xdc00 and n2 <= 0xdfff:
                    n = ((n1 - 0xd800) << 10) + (n2 - 0xdc00) + 0x10000
                    acc.append(unichr_(n))
                else:
                    raise DecodeError('Bad second half of surrogate pair')
            else:
                acc.append(unichr_(n1))
        return self.read_stringlike(delimiter, 'u', u16_escape)

    def read_literal_binary(self):
        """Read the body of a `#"..."` literal (with `\\xHH` escapes) as bytes."""
        s = self.read_stringlike('"', 'x', lambda acc: acc.append(unichr_(self.hexnum(2))))
        try:
            return s.encode('latin-1')
        except UnicodeEncodeError:
            # A character above U+00FF cannot denote a single byte; report it
            # as a syntax error rather than leaking the codec's exception.
            raise DecodeError('Invalid character in ByteString literal')

    def read_hex_binary(self):
        """Read the body of a `#x"..."` literal: hex digit pairs up to the closing quote."""
        acc = bytearray()
        while True:
            self.skip_whitespace()
            if self.peek() == '"':
                self.skip()
                return bytes(acc)
            acc.append(self.hexnum(2))

    def read_base64_binary(self):
        """Read the body of a `#[...]` literal as base64, up to the closing `]`.

        The URL-safe characters '-' and '_' are accepted as '+' and '/'.
        Padding in the input is ignored and re-synthesised before decoding.
        """
        acc = []
        while True:
            self.skip_whitespace()
            c = self.nextchar()
            if c == ']':
                acc.append(u'====')
                try:
                    return base64.b64decode(u''.join(acc))
                except (ValueError, TypeError):
                    # binascii.Error on Python 3 is a ValueError; Python 2
                    # raises TypeError. Either way the literal is malformed.
                    raise DecodeError('Invalid base64 ByteString')
            if c == '-': c = '+'
            if c == '_': c = '/'
            if c == '=': continue
            acc.append(c)

    def upto(self, delimiter):
        """Parse values until `delimiter` is reached; return them as a tuple."""
        vs = []
        while True:
            self.skip_whitespace()
            if self.peek() == delimiter:
                self.skip()
                return tuple(vs)
            vs.append(self.next())

    def read_dictionary(self):
        """Parse `key: value` entries up to the closing `}` into an ImmutableDict."""
        acc = []
        while True:
            self.skip_whitespace()
            if self.peek() == '}':
                self.skip()
                return ImmutableDict.from_kvs(acc)
            acc.append(self.next())
            self.skip_whitespace()
            if self.nextchar() != ':':
                raise DecodeError('Missing expected key/value separator')
            acc.append(self.next())

    def read_raw_symbol(self, acc):
        """Read a bare symbol, ending at whitespace, punctuation or end of input."""
        while not self._atend():
            c = self.peek()
            if c.isspace() or c in '(){}[]<>";,@#:|':
                break
            self.skip()
            acc.append(c)
        return Symbol(u''.join(acc))

    def wrap(self, v):
        """Wrap `v` in Annotated when annotations are being retained."""
        return Annotated(v) if self.include_annotations else v

    def next(self):
        """Parse and return the next value from the input.

        Raises ShortPacket if the input ends before the value is complete,
        and DecodeError on a syntax error.
        """
        self.skip_whitespace()
        c = self.peek()
        if c == '-':
            self.skip()
            return self.wrap(self.read_intpart(['-'], self.nextchar()))
        if c.isdigit():
            self.skip()
            return self.wrap(self.read_intpart([], c))
        if c == '"':
            self.skip()
            return self.wrap(self.read_string('"'))
        if c == '|':
            self.skip()
            return self.wrap(Symbol(self.read_string('|')))
        if c in ';@':
            annotations = self.gather_annotations()
            v = self.next()
            if self.include_annotations:
                v.annotations = annotations + v.annotations
            return v
        if c == ':':
            raise DecodeError('Unexpected key/value separator between items')
        if c == '#':
            self.skip()
            c = self.nextchar()
            if c == 'f': return self.wrap(False)
            if c == 't': return self.wrap(True)
            if c == '{': return self.wrap(frozenset(self.upto('}')))
            if c == '"': return self.wrap(self.read_literal_binary())
            if c == 'x':
                if self.nextchar() != '"':
                    raise DecodeError('Expected open-quote at start of hex ByteString')
                return self.wrap(self.read_hex_binary())
            if c == '[': return self.wrap(self.read_base64_binary())
            if c == '=':
                # Annotations are forced on so that any attached to the
                # ByteString can be detected and rejected. The previous
                # setting must be restored even when the nested parse raises,
                # otherwise a try_next() retry would run with the wrong mode.
                old_ann = self.include_annotations
                self.include_annotations = True
                try:
                    bs_val = self.next()
                finally:
                    self.include_annotations = old_ann
                if len(bs_val.annotations) > 0:
                    raise DecodeError('Annotations not permitted after #=')
                bs_val = bs_val.item
                if not isinstance(bs_val, bytes):
                    raise DecodeError('ByteString must follow #=')
                return self.wrap(Decoder(bs_val, include_annotations = self.include_annotations).next())
            if c == '!':
                if self.parse_embedded is None:
                    raise DecodeError('No parse_embedded function supplied')
                return self.wrap(self.parse_embedded(self.next()))
            raise DecodeError('Invalid # syntax')
        if c == '<':
            self.skip()
            vs = self.upto('>')
            if len(vs) == 0:
                raise DecodeError('Missing record label')
            return self.wrap(Record(vs[0], vs[1:]))
        if c == '[':
            self.skip()
            return self.wrap(self.upto(']'))
        if c == '{':
            self.skip()
            return self.wrap(self.read_dictionary())
        if c in '>]}':
            raise DecodeError('Unexpected ' + c)
        self.skip()
        return self.wrap(self.read_raw_symbol([c]))

    def try_next(self):
        """Like `next`, but on ShortPacket rewind the cursor and return None."""
        start = self.index
        try:
            return self.next()
        except ShortPacket:
            self.index = start
            return None
def parse(bs, **kwargs):
    """Parse and return the first value in the text `bs`.

    Extra keyword arguments are passed to the Parser constructor.
    """
    parser = Parser(input_buffer=bs, **kwargs)
    return parser.next()
def parse_with_annotations(bs, **kwargs):
    """Parse and return the first value in the text `bs`, retaining annotations.

    Extra keyword arguments are passed to the Parser constructor.
    """
    parser = Parser(input_buffer=bs, include_annotations=True, **kwargs)
    return parser.next()
class Formatter(TextCodec):
    """Writes values in the textual syntax, accumulating output in `chunks`.

    `format_embedded` is called on any value that matches no built-in case
    and is not iterable; its result is written after a `#!` prefix.
    """

    def __init__(self, format_embedded=None):
        super(Formatter, self).__init__()
        self.chunks = []
        self.format_embedded = format_embedded

    def contents(self):
        """Return everything written so far as a single string."""
        return u''.join(self.chunks)

    def write_stringlike_char(self, c):
        """Write one character of a quoted string or symbol, escaping as needed.

        The enclosing delimiter is not handled here; callers escape it themselves.
        """
        if c == '\\': self.chunks.append('\\\\')
        elif c == '\x08': self.chunks.append('\\b')
        elif c == '\x0c': self.chunks.append('\\f')
        elif c == '\x0a': self.chunks.append('\\n')
        elif c == '\x0d': self.chunks.append('\\r')
        elif c == '\x09': self.chunks.append('\\t')
        else: self.chunks.append(c)

    def write_seq(self, opener, closer, vs):
        """Write the items of `vs`, separated by spaces, between `opener` and `closer`."""
        self.chunks.append(opener)
        first_item = True
        for v in vs:
            if first_item:
                first_item = False
            else:
                self.chunks.append(' ')
            self.append(v)
        self.chunks.append(closer)

    def append(self, v):
        """Write the text form of `v`.

        Objects with `__preserve__` are first converted by calling it
        repeatedly; objects with `__preserve_write_text__` write themselves.

        Raises EncodeError if `v` has to be written as an embedded value but
        no `format_embedded` function was supplied.
        """
        while hasattr(v, '__preserve__'):
            v = v.__preserve__()

        if hasattr(v, '__preserve_write_text__'):
            v.__preserve_write_text__(self)
        elif v is False:
            self.chunks.append('#f')
        elif v is True:
            self.chunks.append('#t')
        elif isinstance(v, float):
            self.chunks.append(repr(v))
        elif isinstance(v, numbers.Number):
            # NOTE(review): '%d' truncates non-integral Number types such as
            # Fraction -- confirm only integers are expected to reach here.
            self.chunks.append('%d' % (v,))
        elif isinstance(v, bytes):
            self.chunks.append('#[%s]' % (base64.b64encode(v).decode('ascii'),))
        elif isinstance(v, basestring_):
            self.chunks.append('"')
            for c in v:
                if c == '"': self.chunks.append('\\"')
                else: self.write_stringlike_char(c)
            self.chunks.append('"')
        elif isinstance(v, (list, tuple)):
            self.write_seq('[', ']', v)
        elif isinstance(v, (set, frozenset)):
            self.write_seq('#{', '}', v)
        elif isinstance(v, dict):
            self.chunks.append('{')
            need_comma = False
            # Distinct loop names: do not rebind `v` while iterating over it.
            for (key, value) in v.items():
                if need_comma:
                    self.chunks.append(', ')
                else:
                    need_comma = True
                self.append(key)
                self.chunks.append(': ')
                self.append(value)
            self.chunks.append('}')
        else:
            try:
                i = iter(v)
            except TypeError:
                i = None
            if i is None:
                # Not iterable: treat as an embedded value.
                if self.format_embedded is None:
                    raise EncodeError('Cannot format embedded value: no format_embedded function supplied')
                self.chunks.append('#!')
                self.append(self.format_embedded(v))
                return
            self.write_seq('[', ']', i)
def stringify(v, **kwargs):
    """Return the text form of `v`.

    Keyword arguments are passed to the Formatter constructor.
    """
    formatter = Formatter(**kwargs)
    formatter.append(v)
    text = formatter.contents()
    return text

View File

@ -1,3 +1,4 @@
import re
import sys
import struct
@ -8,6 +9,7 @@ class Float(object):
self.value = value
def __eq__(self, other):
other = _unwrap(other)
if other.__class__ is self.__class__:
return self.value == other.value
@ -20,15 +22,22 @@ class Float(object):
def __repr__(self):
return 'Float(' + repr(self.value) + ')'
def __preserve_on__(self, encoder):
def __preserve_write_binary__(self, encoder):
encoder.buffer.append(0x82)
encoder.buffer.extend(struct.pack('>f', self.value))
def __preserve_write_text__(self, formatter):
    # Text form: the repr of the wrapped value followed by an 'f' suffix.
    text = repr(self.value) + 'f'
    formatter.chunks.append(text)
# Symbol names matching this pattern are written bare by
# Symbol.__preserve_write_text__; every other name is written pipe-quoted.
# FIXME: This regular expression is conservatively correct, but Anglo-chauvinistic
# (only ASCII letters are accepted, so names with other letters are always quoted).
RAW_SYMBOL_RE = re.compile(r'^[a-zA-Z~!$%^&*?_=+/.][-a-zA-Z~!$%^&*?_=+/.0-9]*$')
class Symbol(object):
def __init__(self, name):
self.name = name.name if isinstance(name, Symbol) else name
def __eq__(self, other):
other = _unwrap(other)
return isinstance(other, Symbol) and self.name == other.name
def __ne__(self, other):
@ -40,12 +49,22 @@ class Symbol(object):
def __repr__(self):
return '#' + self.name
def __preserve_on__(self, encoder):
def __preserve_write_binary__(self, encoder):
bs = self.name.encode('utf-8')
encoder.buffer.append(0xb3)
encoder.varint(len(bs))
encoder.buffer.extend(bs)
def __preserve_write_text__(self, formatter):
    # Names matching RAW_SYMBOL_RE are written bare.
    if RAW_SYMBOL_RE.match(self.name):
        formatter.chunks.append(self.name)
        return
    # Anything else is written between pipes, with '|' itself escaped and all
    # other characters escaped by the formatter's string rules.
    formatter.chunks.append('|')
    for ch in self.name:
        if ch != '|':
            formatter.write_stringlike_char(ch)
        else:
            formatter.chunks.append('\\|')
    formatter.chunks.append('|')
class Record(object):
def __init__(self, key, fields):
self.key = key
@ -53,6 +72,7 @@ class Record(object):
self.__hash = None
def __eq__(self, other):
other = _unwrap(other)
return isinstance(other, Record) and (self.key, self.fields) == (other.key, other.fields)
def __ne__(self, other):
@ -66,13 +86,16 @@ class Record(object):
def __repr__(self):
return str(self.key) + '(' + ', '.join((repr(f) for f in self.fields)) + ')'
def __preserve_on__(self, encoder):
def __preserve_write_binary__(self, encoder):
encoder.buffer.append(0xb4)
encoder.append(self.key)
for f in self.fields:
encoder.append(f)
encoder.buffer.append(0x84)
def __preserve_write_text__(self, formatter):
    # Text form: <key field field ...>, with the key as the first item.
    items = (self.key,) + self.fields
    formatter.write_seq('<', '>', items)
def __getitem__(self, index):
return self.fields[index]
@ -114,6 +137,7 @@ class RecordConstructorInfo(object):
self.arity = arity
def __eq__(self, other):
other = _unwrap(other)
return isinstance(other, RecordConstructorInfo) and \
(self.key, self.arity) == (other.key, other.arity)
@ -180,12 +204,19 @@ class Annotated(object):
self.annotations = []
self.item = item
def __preserve_on__(self, encoder):
def __preserve_write_binary__(self, encoder):
for a in self.annotations:
encoder.buffer.append(0x85)
encoder.append(a)
encoder.append(self.item)
def __preserve_write_text__(self, formatter):
    # Each annotation is written as '@' + annotation + ' ', then the item itself.
    for annotation in self.annotations:
        formatter.chunks.append('@')
        formatter.append(annotation)
        formatter.chunks.append(' ')
    underlying = self.item
    formatter.append(underlying)
def strip(self, depth=inf):
return strip_annotations(self, depth)
@ -193,8 +224,7 @@ class Annotated(object):
return strip_annotations(self, 1)
def __eq__(self, other):
if other.__class__ is self.__class__:
return self.item == other.item
return self.item == _unwrap(other)
def __ne__(self, other):
return not self.__eq__(other)
@ -240,3 +270,9 @@ def annotate(v, *anns):
for a in anns:
v.annotations.append(a)
return v
def _unwrap(x):
    """Return the item inside `x` if `x` is annotated; otherwise return `x` unchanged."""
    return x.item if is_annotated(x) else x