diff --git a/implementations/python/preserves/schema.py b/implementations/python/preserves/schema.py index 37625b6..5b52b52 100644 --- a/implementations/python/preserves/schema.py +++ b/implementations/python/preserves/schema.py @@ -136,7 +136,10 @@ class SchemaEntity: def __repr__(self): n = self._constructor_name() if self.SIMPLE: - return n + '(' + repr(self.value) + ')' + if self.EMPTY: + return n + '()' + else: + return n + '(' + repr(self.value) + ')' else: return n + ' ' + repr(self._as_dict()) @@ -180,6 +183,7 @@ def safesetattr(o, k, v): setattr(o, k, v) class Definition(SchemaEntity): + EMPTY = False SIMPLE = False FIELD_NAMES = [] ENUMERATION = None @@ -193,9 +197,13 @@ class Definition(SchemaEntity): def __init__(self, *args): self._fields = args if self.SIMPLE: - if len(args) != 1: - raise Exception('%s needs exactly one argument' % (self._constructor_name(),)) - self.value = args[0] + if self.EMPTY: + if len(args) != 0: + raise Exception('%s takes no arguments' % (self._constructor_name(),)) + else: + if len(args) != 1: + raise Exception('%s needs exactly one argument' % (self._constructor_name(),)) + self.value = args[0] else: if len(args) != len(self.FIELD_NAMES): raise Exception('%s needs argument(s) %r' % (self._constructor_name(), self.FIELD_NAMES)) @@ -225,6 +233,7 @@ class Definition(SchemaEntity): cls.SCHEMA = schema cls.MODULE_PATH = module_path cls.NAME = name + cls.EMPTY = is_empty_pattern(schema) cls.SIMPLE = is_simple_pattern(schema) cls.FIELD_NAMES = [] cls.VARIANT = variant @@ -235,7 +244,11 @@ class Definition(SchemaEntity): def try_decode(cls, v): if cls.SIMPLE: i = cls.parse(cls.SCHEMA, v, []) - if i is not None: return cls(i) + if i is not None: + if cls.EMPTY: + return cls() + else: + return cls(i) else: args = [] if cls.parse(cls.SCHEMA, v, args) is not None: return cls(*args) @@ -260,6 +273,9 @@ SIMPLE_PATTERN_KEYS = [ATOM, EMBEDDED, LIT, SEQOF, SETOF, DICTOF, REF] def is_simple_pattern(p): return p == ANY or (isinstance(p, Record) and p.key in SIMPLE_PATTERN_KEYS) +def is_empty_pattern(p): + return isinstance(p, Record) and p.key == LIT + def gather_defined_field_names(s, acc): if is_simple_pattern(s): pass