Small improvements

This commit is contained in:
Tony Garnock-Jones 2021-08-15 23:30:46 -04:00
parent abe60b3506
commit e45ff6b020
1 changed files with 37 additions and 6 deletions

View File

@ -1,5 +1,6 @@
from .preserves import *
import pathlib
import keyword
AND = Symbol('and')
ANY = Symbol('any')
@ -33,6 +34,7 @@ class SchemaEntity:
SCHEMA = None
MODULE_PATH = None
NAME = None
VARIANT = None
@classmethod
def decode(cls, v):
@ -132,9 +134,7 @@ class SchemaEntity:
raise Exception('Subclass responsibility')
def __repr__(self):
n = self.NAME.name
if self.VARIANT:
n = n + '.' + self.VARIANT.name
n = self._constructor_name()
if self.SIMPLE:
return n + '(' + repr(self.value) + ')'
else:
@ -161,6 +161,7 @@ class Enumeration(SchemaEntity):
c = pretty_subclass(Definition, module_path_str(module_path + (name,)), n.name)
c._set_schema(rootns, module_path, name, d, n, cls)
cls.VARIANTS.append((n, c))
safesetattr(cls, n.name, c)
@classmethod
def try_decode(cls, v):
@ -172,22 +173,46 @@ class Enumeration(SchemaEntity):
def _encode(self):
raise Exception('Cannot encode instance of Enumeration')
def safesetattr(o, k, v):
if keyword.iskeyword(k):
setattr(o, k + '_', v)
else:
setattr(o, k, v)
class Definition(SchemaEntity):
SIMPLE = False
FIELD_NAMES = []
VARIANT = None
ENUMERATION = None
def _constructor_name(self):
if self.VARIANT is None:
return self.NAME.name
else:
return self.NAME.name + '.' + self.VARIANT.name
def __init__(self, *args):
self._fields = args
if self.SIMPLE:
if len(args) != 1:
raise Exception('%s needs exactly one argument' % (self._constructor_name(),))
self.value = args[0]
else:
if len(args) != len(self.FIELD_NAMES):
raise Exception('%s needs argument(s) %r' % (self._constructor_name(), self.FIELD_NAMES))
i = 0
for k in self.FIELD_NAMES:
setattr(self, k, args[i])
safesetattr(self, k, args[i])
i = i + 1
def __eq__(self, other):
return (other.__class__ is self.__class__) and (self._fields == other._fields)
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash(self._fields) ^ hash(self.__class_)
def _accept(self, visitor):
if self.VARIANT is None:
return visitor(*self._fields)
@ -226,7 +251,7 @@ class Definition(SchemaEntity):
return getattr(self, name)
def __setitem__(self, name, value):
return setattr(self, name, value)
return safesetattr(self, name, value)
def module_path_str(mp):
return '.'.join([e.name for e in mp])
@ -343,6 +368,11 @@ if __name__ == '__main__':
x = Decoder(f.read()).next()
print(c.root.schema.Schema.decode(x))
def m(self, x):
return ['yay', self.embeddedType, x]
c.root.schema.Schema.f = m
print(c.root.schema.Schema.decode(x).f(123))
print()
d = Compiler()
@ -351,5 +381,6 @@ if __name__ == '__main__':
with open(path_bin_filename, 'rb') as f:
x = Decoder(f.read()).next()
print(c.root.schema.Schema.decode(x))
print(c.root.schema.Schema.decode(x) == c.root.schema.Schema.decode(x))
print()
print(d.root)