forked from syndicate-lang/preserves
Python text codec
This commit is contained in:
parent
123b6222ca
commit
8afc8f1eae
|
@ -1,6 +1,7 @@
|
|||
from .repr import Float, Symbol, Record, ImmutableDict
|
||||
from .repr import Annotated, is_annotated, strip_annotations, annotate
|
||||
from .values import Float, Symbol, Record, ImmutableDict
|
||||
from .values import Annotated, is_annotated, strip_annotations, annotate
|
||||
|
||||
from .error import DecodeError, EncodeError, ShortPacket
|
||||
|
||||
from .binary import Decoder, Encoder, decode, decode_with_annotations, encode
|
||||
from .text import Parser, Formatter, parse, parse_with_annotations, stringify
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
import numbers
|
||||
import struct
|
||||
|
||||
from .repr import *
|
||||
from .values import *
|
||||
from .error import *
|
||||
from .compat import basestring_, ord_
|
||||
|
||||
class Codec(object): pass
|
||||
class BinaryCodec(object): pass
|
||||
|
||||
class Decoder(Codec):
|
||||
class Decoder(BinaryCodec):
|
||||
def __init__(self, packet=b'', include_annotations=False, decode_embedded=None):
|
||||
super(Decoder, self).__init__()
|
||||
self.packet = packet
|
||||
|
@ -112,7 +112,7 @@ def decode(bs, **kwargs):
|
|||
def decode_with_annotations(bs, **kwargs):
|
||||
return Decoder(packet=bs, include_annotations=True, **kwargs).next()
|
||||
|
||||
class Encoder(Codec):
|
||||
class Encoder(BinaryCodec):
|
||||
def __init__(self, encode_embedded=id):
|
||||
super(Encoder, self).__init__()
|
||||
self.buffer = bytearray()
|
||||
|
@ -153,8 +153,11 @@ class Encoder(Codec):
|
|||
self.buffer.extend(bs)
|
||||
|
||||
def append(self, v):
|
||||
if hasattr(v, '__preserve_on__'):
|
||||
v.__preserve_on__(self)
|
||||
while hasattr(v, '__preserve__'):
|
||||
v = v.__preserve__()
|
||||
|
||||
if hasattr(v, '__preserve_write_binary__'):
|
||||
v.__preserve_write_binary__(self)
|
||||
elif v is False:
|
||||
self.buffer.append(0x80)
|
||||
elif v is True:
|
||||
|
|
|
@ -7,3 +7,8 @@ if isinstance(chr(123), bytes):
|
|||
ord_ = ord
|
||||
else:
|
||||
ord_ = lambda x: x
|
||||
|
||||
try:
|
||||
unichr_ = unichr
|
||||
except NameError:
|
||||
unichr_ = chr
|
||||
|
|
|
@ -130,7 +130,7 @@ class SchemaObject:
|
|||
return ()
|
||||
raise ValueError('Bad schema')
|
||||
|
||||
def _encode(self):
|
||||
def __preserve__(self):
|
||||
raise NotImplementedError('Subclass responsibility')
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -173,7 +173,7 @@ class Enumeration(SchemaObject):
|
|||
if i is not None: return i
|
||||
return None
|
||||
|
||||
def _encode(self):
|
||||
def __preserve__(self):
|
||||
raise TypeError('Cannot encode instance of Enumeration')
|
||||
|
||||
def safeattrname(k):
|
||||
|
@ -257,7 +257,7 @@ class Definition(SchemaObject):
|
|||
if cls.parse(cls.SCHEMA, v, args) is not None: return cls(*args)
|
||||
return None
|
||||
|
||||
def _encode(self):
|
||||
def __preserve__(self):
|
||||
if self.SIMPLE:
|
||||
if self.EMPTY:
|
||||
return encode(self.SCHEMA, ())
|
||||
|
@ -293,7 +293,7 @@ def encode(p, v):
|
|||
if p.key == DICTOF:
|
||||
return dict((encode(p[0], k), encode(p[1], w)) for (k, w) in v.items())
|
||||
if p.key == REF:
|
||||
return v._encode()
|
||||
return v.__preserve__()
|
||||
if p.key == REC:
|
||||
return Record(encode(p[0], v), encode(p[1], v))
|
||||
if p.key == TUPLE:
|
||||
|
@ -426,8 +426,8 @@ if __name__ == '__main__':
|
|||
with open(__metaschema_filename, 'rb') as f:
|
||||
x = Decoder(f.read()).next()
|
||||
print(meta.Schema.decode(x))
|
||||
print(meta.Schema.decode(x)._encode())
|
||||
assert meta.Schema.decode(x)._encode() == x
|
||||
print(meta.Schema.decode(x).__preserve__())
|
||||
assert meta.Schema.decode(x).__preserve__() == x
|
||||
|
||||
@extend(meta.Schema)
|
||||
def f(self, x):
|
||||
|
@ -443,7 +443,7 @@ if __name__ == '__main__':
|
|||
x = Decoder(f.read()).next()
|
||||
print(meta.Schema.decode(x))
|
||||
assert meta.Schema.decode(x) == meta.Schema.decode(x)
|
||||
assert meta.Schema.decode(x)._encode() == x
|
||||
assert meta.Schema.decode(x).__preserve__() == x
|
||||
|
||||
print()
|
||||
print(path)
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import numbers
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from . import *
|
||||
from .compat import basestring_, ord_
|
||||
from .values import _unwrap
|
||||
|
||||
if isinstance(chr(123), bytes):
|
||||
def _byte(x):
|
||||
|
@ -44,7 +46,7 @@ def _e(v):
|
|||
def _R(k, *args):
|
||||
return Record(Symbol(k), args)
|
||||
|
||||
class CodecTests(unittest.TestCase):
|
||||
class BinaryCodecTests(unittest.TestCase):
|
||||
def _roundtrip(self, forward, expected, back=None, nondeterministic=False):
|
||||
if back is None: back = forward
|
||||
self.assertEqual(_d(_e(forward)), back)
|
||||
|
@ -190,6 +192,25 @@ class CodecTests(unittest.TestCase):
|
|||
self.assertEqual(decode(b'\xb5\x86\x90\x86\x91\x84', decode_embedded = dec),
|
||||
(123, 234))
|
||||
|
||||
def load_binary_samples():
|
||||
with open(os.path.join(os.path.dirname(__file__), '../../../tests/samples.bin'), 'rb') as f:
|
||||
return Decoder(f.read(), include_annotations=True, decode_embedded=Embedded).next()
|
||||
|
||||
def load_text_samples():
|
||||
with open(os.path.join(os.path.dirname(__file__), '../../../tests/samples.pr'), 'rt') as f:
|
||||
return Parser(f.read(), include_annotations=True, parse_embedded=Embedded).next()
|
||||
|
||||
class TextCodecTests(unittest.TestCase):
|
||||
def test_samples_bin_eq_txt(self):
|
||||
b = load_binary_samples()
|
||||
t = load_text_samples()
|
||||
self.assertEqual(b, t)
|
||||
|
||||
def test_txt_roundtrip(self):
|
||||
b = load_binary_samples()
|
||||
s = stringify(b, format_embedded=Embedded.value)
|
||||
self.assertEqual(parse(s, include_annotations=True, parse_embedded=Embedded), b)
|
||||
|
||||
def add_method(d, tName, fn):
|
||||
if hasattr(fn, 'func_name'):
|
||||
# python2
|
||||
|
@ -265,6 +286,7 @@ class Embedded:
|
|||
return i.v
|
||||
|
||||
def __eq__(self, other):
|
||||
other = _unwrap(other)
|
||||
if other.__class__ is self.__class__:
|
||||
return self.v == other.v
|
||||
|
||||
|
@ -272,13 +294,9 @@ class Embedded:
|
|||
return hash(self.v)
|
||||
|
||||
class CommonTestSuite(unittest.TestCase):
|
||||
import os
|
||||
with open(os.path.join(os.path.dirname(__file__),
|
||||
'../../../tests/samples.bin'), 'rb') as f:
|
||||
samples = Decoder(f.read(), include_annotations=True, decode_embedded=Embedded).next()
|
||||
|
||||
TestCases = Record.makeConstructor('TestCases', 'cases')
|
||||
|
||||
samples = load_binary_samples()
|
||||
tests = TestCases._cases(samples.peel()).peel()
|
||||
for (tName0, t0) in tests.items():
|
||||
tName = tName0.strip().name
|
||||
|
|
|
@ -0,0 +1,386 @@
|
|||
import numbers
|
||||
import struct
|
||||
import base64
|
||||
|
||||
from .values import *
|
||||
from .error import *
|
||||
from .compat import basestring_, unichr_
|
||||
from .binary import Decoder
|
||||
|
||||
class TextCodec(object): pass
|
||||
|
||||
class Parser(TextCodec):
|
||||
def __init__(self, input_buffer=u'', include_annotations=False, parse_embedded=None):
|
||||
super(Parser, self).__init__()
|
||||
self.input_buffer = input_buffer
|
||||
self.index = 0
|
||||
self.include_annotations = include_annotations
|
||||
self.parse_embedded = parse_embedded
|
||||
|
||||
def extend(self, text):
|
||||
self.input_buffer = self.input_buffer[self.index:] + text
|
||||
self.index = 0
|
||||
|
||||
def _atend(self):
|
||||
return self.index >= len(self.input_buffer)
|
||||
|
||||
def peek(self):
|
||||
if self._atend():
|
||||
raise ShortPacket('Short input buffer')
|
||||
return self.input_buffer[self.index]
|
||||
|
||||
def skip(self):
|
||||
self.index = self.index + 1
|
||||
|
||||
def nextchar(self):
|
||||
c = self.peek()
|
||||
self.skip()
|
||||
return c
|
||||
|
||||
def skip_whitespace(self):
|
||||
while not self._atend():
|
||||
c = self.peek()
|
||||
if not (c.isspace() or c == ','):
|
||||
break
|
||||
self.skip()
|
||||
|
||||
def gather_annotations(self):
|
||||
vs = []
|
||||
while True:
|
||||
self.skip_whitespace()
|
||||
c = self.peek()
|
||||
if c == ';':
|
||||
self.skip()
|
||||
vs.append(self.comment_line())
|
||||
elif c == '@':
|
||||
self.skip()
|
||||
vs.append(self.next())
|
||||
else:
|
||||
return vs
|
||||
|
||||
def comment_line(self):
|
||||
s = []
|
||||
while True:
|
||||
c = self.nextchar()
|
||||
if c == '\r' or c == '\n':
|
||||
return self.wrap(u''.join(s))
|
||||
s.append(c)
|
||||
|
||||
def read_intpart(self, acc, c):
|
||||
if c == '0':
|
||||
acc.append(c)
|
||||
else:
|
||||
self.read_digit1(acc, c)
|
||||
return self.read_fracexp(acc)
|
||||
|
||||
def read_fracexp(self, acc):
|
||||
is_float = False
|
||||
if self.peek() == '.':
|
||||
is_float = True
|
||||
acc.append(self.nextchar())
|
||||
self.read_digit1(acc, self.nextchar())
|
||||
if self.peek() in 'eE':
|
||||
acc.append(self.nextchar())
|
||||
return self.read_sign_and_exp(acc)
|
||||
else:
|
||||
return self.finish_number(acc, is_float)
|
||||
|
||||
def read_sign_and_exp(self, acc):
|
||||
if self.peek() in '+-':
|
||||
acc.append(self.nextchar())
|
||||
self.read_digit1(acc, self.nextchar())
|
||||
return self.finish_number(acc, True)
|
||||
|
||||
def finish_number(self, acc, is_float):
|
||||
if is_float:
|
||||
if self.peek() in 'fF':
|
||||
self.skip()
|
||||
return Float(float(u''.join(acc)))
|
||||
else:
|
||||
return float(u''.join(acc))
|
||||
else:
|
||||
return int(u''.join(acc))
|
||||
|
||||
def read_digit1(self, acc, c):
|
||||
if not c.isdigit():
|
||||
raise DecodeError('Incomplete number')
|
||||
acc.append(c)
|
||||
while not self._atend():
|
||||
if not self.peek().isdigit():
|
||||
break
|
||||
acc.append(self.nextchar())
|
||||
|
||||
def read_stringlike(self, terminator, hexescape, hexescaper):
|
||||
acc = []
|
||||
while True:
|
||||
c = self.nextchar()
|
||||
if c == terminator:
|
||||
return u''.join(acc)
|
||||
if c == '\\':
|
||||
c = self.nextchar()
|
||||
if c == hexescape: hexescaper(acc)
|
||||
elif c == terminator or c == '\\' or c == '/': acc.append(c)
|
||||
elif c == 'b': acc.append(u'\x08')
|
||||
elif c == 'f': acc.append(u'\x0c')
|
||||
elif c == 'n': acc.append(u'\x0a')
|
||||
elif c == 'r': acc.append(u'\x0d')
|
||||
elif c == 't': acc.append(u'\x09')
|
||||
else: raise DecodeError('Invalid escape code')
|
||||
else:
|
||||
acc.append(c)
|
||||
|
||||
def hexnum(self, count):
|
||||
v = 0
|
||||
for i in range(count):
|
||||
c = self.nextchar().lower()
|
||||
if c >= '0' and c <= '9':
|
||||
v = v << 4 | (ord(c) - ord('0'))
|
||||
elif c >= 'a' and c <= 'f':
|
||||
v = v << 4 | (ord(c) - ord('a') + 10)
|
||||
else:
|
||||
raise DecodeError('Bad hex escape')
|
||||
return v
|
||||
|
||||
def read_string(self, delimiter):
|
||||
def u16_escape(acc):
|
||||
n1 = self.hexnum(4)
|
||||
if n1 >= 0xd800 and n1 <= 0xdbff:
|
||||
ok = True
|
||||
ok = ok and self.nextchar() == '\\'
|
||||
ok = ok and self.nextchar() == 'u'
|
||||
if not ok:
|
||||
raise DecodeError('Missing second half of surrogate pair')
|
||||
n2 = self.hexnum(4)
|
||||
if n2 >= 0xdc00 and n2 <= 0xdfff:
|
||||
n = ((n1 - 0xd800) << 10) + (n2 - 0xdc00) + 0x10000
|
||||
acc.append(unichr_(n))
|
||||
else:
|
||||
raise DecodeError('Bad second half of surrogate pair')
|
||||
else:
|
||||
acc.append(unichr_(n1))
|
||||
return self.read_stringlike(delimiter, 'u', u16_escape)
|
||||
|
||||
def read_literal_binary(self):
|
||||
s = self.read_stringlike('"', 'x', lambda acc: acc.append(unichr_(self.hexnum(2))))
|
||||
return s.encode('latin-1')
|
||||
|
||||
def read_hex_binary(self):
|
||||
acc = bytearray()
|
||||
while True:
|
||||
self.skip_whitespace()
|
||||
if self.peek() == '"':
|
||||
self.skip()
|
||||
return bytes(acc)
|
||||
acc.append(self.hexnum(2))
|
||||
|
||||
def read_base64_binary(self):
|
||||
acc = []
|
||||
while True:
|
||||
self.skip_whitespace()
|
||||
c = self.nextchar()
|
||||
if c == ']':
|
||||
acc.append(u'====')
|
||||
return base64.b64decode(u''.join(acc))
|
||||
if c == '-': c = '+'
|
||||
if c == '_': c = '/'
|
||||
if c == '=': continue
|
||||
acc.append(c)
|
||||
|
||||
def upto(self, delimiter):
|
||||
vs = []
|
||||
while True:
|
||||
self.skip_whitespace()
|
||||
if self.peek() == delimiter:
|
||||
self.skip()
|
||||
return tuple(vs)
|
||||
vs.append(self.next())
|
||||
|
||||
def read_dictionary(self):
|
||||
acc = []
|
||||
while True:
|
||||
self.skip_whitespace()
|
||||
if self.peek() == '}':
|
||||
self.skip()
|
||||
return ImmutableDict.from_kvs(acc)
|
||||
acc.append(self.next())
|
||||
self.skip_whitespace()
|
||||
if self.nextchar() != ':':
|
||||
raise DecodeError('Missing expected key/value separator')
|
||||
acc.append(self.next())
|
||||
|
||||
def read_raw_symbol(self, acc):
|
||||
while not self._atend():
|
||||
c = self.peek()
|
||||
if c.isspace() or c in '(){}[]<>";,@#:|':
|
||||
break
|
||||
self.skip()
|
||||
acc.append(c)
|
||||
return Symbol(u''.join(acc))
|
||||
|
||||
def wrap(self, v):
|
||||
return Annotated(v) if self.include_annotations else v
|
||||
|
||||
def next(self):
|
||||
self.skip_whitespace()
|
||||
c = self.peek()
|
||||
if c == '-':
|
||||
self.skip()
|
||||
return self.wrap(self.read_intpart(['-'], self.nextchar()))
|
||||
if c.isdigit():
|
||||
self.skip()
|
||||
return self.wrap(self.read_intpart([], c))
|
||||
if c == '"':
|
||||
self.skip()
|
||||
return self.wrap(self.read_string('"'))
|
||||
if c == '|':
|
||||
self.skip()
|
||||
return self.wrap(Symbol(self.read_string('|')))
|
||||
if c in ';@':
|
||||
annotations = self.gather_annotations()
|
||||
v = self.next()
|
||||
if self.include_annotations:
|
||||
v.annotations = annotations + v.annotations
|
||||
return v
|
||||
if c == ':':
|
||||
raise DecodeError('Unexpected key/value separator between items')
|
||||
if c == '#':
|
||||
self.skip()
|
||||
c = self.nextchar()
|
||||
if c == 'f': return self.wrap(False)
|
||||
if c == 't': return self.wrap(True)
|
||||
if c == '{': return self.wrap(frozenset(self.upto('}')))
|
||||
if c == '"': return self.wrap(self.read_literal_binary())
|
||||
if c == 'x':
|
||||
if self.nextchar() != '"':
|
||||
raise DecodeError('Expected open-quote at start of hex ByteString')
|
||||
return self.wrap(self.read_hex_binary())
|
||||
if c == '[': return self.wrap(self.read_base64_binary())
|
||||
if c == '=':
|
||||
old_ann = self.include_annotations
|
||||
self.include_annotations = True
|
||||
bs_val = self.next()
|
||||
self.include_annotations = old_ann
|
||||
if len(bs_val.annotations) > 0:
|
||||
raise DecodeError('Annotations not permitted after #=')
|
||||
bs_val = bs_val.item
|
||||
if not isinstance(bs_val, bytes):
|
||||
raise DecodeError('ByteString must follow #=')
|
||||
return self.wrap(Decoder(bs_val, include_annotations = self.include_annotations).next())
|
||||
if c == '!':
|
||||
return self.wrap(self.parse_embedded(self.next()))
|
||||
raise DecodeError('Invalid # syntax')
|
||||
if c == '<':
|
||||
self.skip()
|
||||
vs = self.upto('>')
|
||||
if len(vs) == 0:
|
||||
raise DecodeError('Missing record label')
|
||||
return self.wrap(Record(vs[0], vs[1:]))
|
||||
if c == '[':
|
||||
self.skip()
|
||||
return self.wrap(self.upto(']'))
|
||||
if c == '{':
|
||||
self.skip()
|
||||
return self.wrap(self.read_dictionary())
|
||||
if c in '>]}':
|
||||
raise DecodeError('Unexpected ' + c)
|
||||
self.skip()
|
||||
return self.wrap(self.read_raw_symbol([c]))
|
||||
|
||||
def try_next(self):
|
||||
start = self.index
|
||||
try:
|
||||
return self.next()
|
||||
except ShortPacket:
|
||||
self.index = start
|
||||
return None
|
||||
|
||||
def parse(bs, **kwargs):
|
||||
return Parser(input_buffer=bs, **kwargs).next()
|
||||
|
||||
def parse_with_annotations(bs, **kwargs):
|
||||
return Parser(input_buffer=bs, include_annotations=True, **kwargs).next()
|
||||
|
||||
class Formatter(TextCodec):
|
||||
def __init__(self, format_embedded=None):
|
||||
super(Formatter, self).__init__()
|
||||
self.chunks = []
|
||||
self.format_embedded = format_embedded
|
||||
|
||||
def contents(self):
|
||||
return u''.join(self.chunks)
|
||||
|
||||
def write_stringlike_char(self, c):
|
||||
if c == '\\': self.chunks.append('\\\\')
|
||||
elif c == '\x08': self.chunks.append('\\b')
|
||||
elif c == '\x0c': self.chunks.append('\\f')
|
||||
elif c == '\x0a': self.chunks.append('\\n')
|
||||
elif c == '\x0d': self.chunks.append('\\r')
|
||||
elif c == '\x09': self.chunks.append('\\t')
|
||||
else: self.chunks.append(c)
|
||||
|
||||
def write_seq(self, opener, closer, vs):
|
||||
self.chunks.append(opener)
|
||||
first_item = True
|
||||
for v in vs:
|
||||
if first_item:
|
||||
first_item = False
|
||||
else:
|
||||
self.chunks.append(' ')
|
||||
self.append(v)
|
||||
self.chunks.append(closer)
|
||||
|
||||
def append(self, v):
|
||||
while hasattr(v, '__preserve__'):
|
||||
v = v.__preserve__()
|
||||
|
||||
if hasattr(v, '__preserve_write_text__'):
|
||||
v.__preserve_write_text__(self)
|
||||
elif v is False:
|
||||
self.chunks.append('#f')
|
||||
elif v is True:
|
||||
self.chunks.append('#t')
|
||||
elif isinstance(v, float):
|
||||
self.chunks.append(repr(v))
|
||||
elif isinstance(v, numbers.Number):
|
||||
self.chunks.append('%d' % (v,))
|
||||
elif isinstance(v, bytes):
|
||||
self.chunks.append('#[%s]' % (base64.b64encode(v).decode('ascii'),))
|
||||
elif isinstance(v, basestring_):
|
||||
self.chunks.append('"')
|
||||
for c in v:
|
||||
if c == '"': self.chunks.append('\\"')
|
||||
else: self.write_stringlike_char(c)
|
||||
self.chunks.append('"')
|
||||
elif isinstance(v, list):
|
||||
self.write_seq('[', ']', v)
|
||||
elif isinstance(v, tuple):
|
||||
self.write_seq('[', ']', v)
|
||||
elif isinstance(v, set):
|
||||
self.write_seq('#{', '}', v)
|
||||
elif isinstance(v, frozenset):
|
||||
self.write_seq('#{', '}', v)
|
||||
elif isinstance(v, dict):
|
||||
self.chunks.append('{')
|
||||
need_comma = False
|
||||
for (k, v) in v.items():
|
||||
if need_comma:
|
||||
self.chunks.append(', ')
|
||||
else:
|
||||
need_comma = True
|
||||
self.append(k)
|
||||
self.chunks.append(': ')
|
||||
self.append(v)
|
||||
self.chunks.append('}')
|
||||
else:
|
||||
try:
|
||||
i = iter(v)
|
||||
except TypeError:
|
||||
self.chunks.append('#!')
|
||||
self.append(self.format_embedded(v))
|
||||
return
|
||||
self.write_seq('[', ']', i)
|
||||
|
||||
def stringify(v, **kwargs):
|
||||
e = Formatter(**kwargs)
|
||||
e.append(v)
|
||||
return e.contents()
|
|
@ -1,3 +1,4 @@
|
|||
import re
|
||||
import sys
|
||||
import struct
|
||||
|
||||
|
@ -8,6 +9,7 @@ class Float(object):
|
|||
self.value = value
|
||||
|
||||
def __eq__(self, other):
|
||||
other = _unwrap(other)
|
||||
if other.__class__ is self.__class__:
|
||||
return self.value == other.value
|
||||
|
||||
|
@ -20,15 +22,22 @@ class Float(object):
|
|||
def __repr__(self):
|
||||
return 'Float(' + repr(self.value) + ')'
|
||||
|
||||
def __preserve_on__(self, encoder):
|
||||
def __preserve_write_binary__(self, encoder):
|
||||
encoder.buffer.append(0x82)
|
||||
encoder.buffer.extend(struct.pack('>f', self.value))
|
||||
|
||||
def __preserve_write_text__(self, formatter):
|
||||
formatter.chunks.append(repr(self.value) + 'f')
|
||||
|
||||
# FIXME: This regular expression is conservatively correct, but Anglo-chauvinistic.
|
||||
RAW_SYMBOL_RE = re.compile(r'^[a-zA-Z~!$%^&*?_=+/.][-a-zA-Z~!$%^&*?_=+/.0-9]*$')
|
||||
|
||||
class Symbol(object):
|
||||
def __init__(self, name):
|
||||
self.name = name.name if isinstance(name, Symbol) else name
|
||||
|
||||
def __eq__(self, other):
|
||||
other = _unwrap(other)
|
||||
return isinstance(other, Symbol) and self.name == other.name
|
||||
|
||||
def __ne__(self, other):
|
||||
|
@ -40,12 +49,22 @@ class Symbol(object):
|
|||
def __repr__(self):
|
||||
return '#' + self.name
|
||||
|
||||
def __preserve_on__(self, encoder):
|
||||
def __preserve_write_binary__(self, encoder):
|
||||
bs = self.name.encode('utf-8')
|
||||
encoder.buffer.append(0xb3)
|
||||
encoder.varint(len(bs))
|
||||
encoder.buffer.extend(bs)
|
||||
|
||||
def __preserve_write_text__(self, formatter):
|
||||
if RAW_SYMBOL_RE.match(self.name):
|
||||
formatter.chunks.append(self.name)
|
||||
else:
|
||||
formatter.chunks.append('|')
|
||||
for c in self.name:
|
||||
if c == '|': formatter.chunks.append('\\|')
|
||||
else: formatter.write_stringlike_char(c)
|
||||
formatter.chunks.append('|')
|
||||
|
||||
class Record(object):
|
||||
def __init__(self, key, fields):
|
||||
self.key = key
|
||||
|
@ -53,6 +72,7 @@ class Record(object):
|
|||
self.__hash = None
|
||||
|
||||
def __eq__(self, other):
|
||||
other = _unwrap(other)
|
||||
return isinstance(other, Record) and (self.key, self.fields) == (other.key, other.fields)
|
||||
|
||||
def __ne__(self, other):
|
||||
|
@ -66,13 +86,16 @@ class Record(object):
|
|||
def __repr__(self):
|
||||
return str(self.key) + '(' + ', '.join((repr(f) for f in self.fields)) + ')'
|
||||
|
||||
def __preserve_on__(self, encoder):
|
||||
def __preserve_write_binary__(self, encoder):
|
||||
encoder.buffer.append(0xb4)
|
||||
encoder.append(self.key)
|
||||
for f in self.fields:
|
||||
encoder.append(f)
|
||||
encoder.buffer.append(0x84)
|
||||
|
||||
def __preserve_write_text__(self, formatter):
|
||||
formatter.write_seq('<', '>', (self.key,) + self.fields)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.fields[index]
|
||||
|
||||
|
@ -114,6 +137,7 @@ class RecordConstructorInfo(object):
|
|||
self.arity = arity
|
||||
|
||||
def __eq__(self, other):
|
||||
other = _unwrap(other)
|
||||
return isinstance(other, RecordConstructorInfo) and \
|
||||
(self.key, self.arity) == (other.key, other.arity)
|
||||
|
||||
|
@ -180,12 +204,19 @@ class Annotated(object):
|
|||
self.annotations = []
|
||||
self.item = item
|
||||
|
||||
def __preserve_on__(self, encoder):
|
||||
def __preserve_write_binary__(self, encoder):
|
||||
for a in self.annotations:
|
||||
encoder.buffer.append(0x85)
|
||||
encoder.append(a)
|
||||
encoder.append(self.item)
|
||||
|
||||
def __preserve_write_text__(self, formatter):
|
||||
for a in self.annotations:
|
||||
formatter.chunks.append('@')
|
||||
formatter.append(a)
|
||||
formatter.chunks.append(' ')
|
||||
formatter.append(self.item)
|
||||
|
||||
def strip(self, depth=inf):
|
||||
return strip_annotations(self, depth)
|
||||
|
||||
|
@ -193,8 +224,7 @@ class Annotated(object):
|
|||
return strip_annotations(self, 1)
|
||||
|
||||
def __eq__(self, other):
|
||||
if other.__class__ is self.__class__:
|
||||
return self.item == other.item
|
||||
return self.item == _unwrap(other)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
@ -240,3 +270,9 @@ def annotate(v, *anns):
|
|||
for a in anns:
|
||||
v.annotations.append(a)
|
||||
return v
|
||||
|
||||
def _unwrap(x):
|
||||
if is_annotated(x):
|
||||
return x.item
|
||||
else:
|
||||
return x
|
Loading…
Reference in New Issue