From e45ff6b02051c4c37e839b23a932ac0825f4daee Mon Sep 17 00:00:00 2001 From: Tony Garnock-Jones Date: Sun, 15 Aug 2021 23:30:46 -0400 Subject: [PATCH] Small improvements --- implementations/python/preserves/schema.py | 43 +++++++++++++++++++--- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/implementations/python/preserves/schema.py b/implementations/python/preserves/schema.py index 61e4c51..e6d0d5d 100644 --- a/implementations/python/preserves/schema.py +++ b/implementations/python/preserves/schema.py @@ -1,5 +1,6 @@ from .preserves import * import pathlib +import keyword AND = Symbol('and') ANY = Symbol('any') @@ -33,6 +34,7 @@ class SchemaEntity: SCHEMA = None MODULE_PATH = None NAME = None + VARIANT = None @classmethod def decode(cls, v): @@ -132,9 +134,7 @@ class SchemaEntity: raise Exception('Subclass responsibility') def __repr__(self): - n = self.NAME.name - if self.VARIANT: - n = n + '.' + self.VARIANT.name + n = self._constructor_name() if self.SIMPLE: return n + '(' + repr(self.value) + ')' else: @@ -161,6 +161,7 @@ class Enumeration(SchemaEntity): c = pretty_subclass(Definition, module_path_str(module_path + (name,)), n.name) c._set_schema(rootns, module_path, name, d, n, cls) cls.VARIANTS.append((n, c)) + safesetattr(cls, n.name, c) @classmethod def try_decode(cls, v): @@ -172,22 +173,46 @@ class Enumeration(SchemaEntity): def _encode(self): raise Exception('Cannot encode instance of Enumeration') +def safesetattr(o, k, v): + if keyword.iskeyword(k): + setattr(o, k + '_', v) + else: + setattr(o, k, v) + class Definition(SchemaEntity): SIMPLE = False FIELD_NAMES = [] - VARIANT = None ENUMERATION = None + def _constructor_name(self): + if self.VARIANT is None: + return self.NAME.name + else: + return self.NAME.name + '.' + self.VARIANT.name + def __init__(self, *args): self._fields = args if self.SIMPLE: + if len(args) != 1: + raise Exception('%s needs exactly one argument' % (self._constructor_name(),)) self.value = args[0] else: + if len(args) != len(self.FIELD_NAMES): + raise Exception('%s needs argument(s) %r' % (self._constructor_name(), self.FIELD_NAMES)) i = 0 for k in self.FIELD_NAMES: - setattr(self, k, args[i]) + safesetattr(self, k, args[i]) i = i + 1 + def __eq__(self, other): + return (other.__class__ is self.__class__) and (self._fields == other._fields) + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self._fields) ^ hash(self.__class_) + def _accept(self, visitor): if self.VARIANT is None: return visitor(*self._fields) @@ -226,7 +251,7 @@ class Definition(SchemaEntity): return getattr(self, name) def __setitem__(self, name, value): - return setattr(self, name, value) + return safesetattr(self, name, value) def module_path_str(mp): return '.'.join([e.name for e in mp]) @@ -343,6 +368,11 @@ if __name__ == '__main__': x = Decoder(f.read()).next() print(c.root.schema.Schema.decode(x)) + def m(self, x): + return ['yay', self.embeddedType, x] + c.root.schema.Schema.f = m + print(c.root.schema.Schema.decode(x).f(123)) + print() d = Compiler() @@ -351,5 +381,6 @@ if __name__ == '__main__': with open(path_bin_filename, 'rb') as f: x = Decoder(f.read()).next() print(c.root.schema.Schema.decode(x)) + print(c.root.schema.Schema.decode(x) == c.root.schema.Schema.decode(x)) print() print(d.root)