Handle empty patterns specially

This commit is contained in:
Tony Garnock-Jones 2021-08-15 23:51:07 -04:00
parent 59bcced776
commit fc1d6afc28
1 changed files with 21 additions and 5 deletions

View File

@ -136,7 +136,10 @@ class SchemaEntity:
def __repr__(self):
n = self._constructor_name()
if self.SIMPLE:
return n + '(' + repr(self.value) + ')'
if self.EMPTY:
return n + '()'
else:
return n + '(' + repr(self.value) + ')'
else:
return n + ' ' + repr(self._as_dict())
@ -180,6 +183,7 @@ def safesetattr(o, k, v):
setattr(o, k, v)
class Definition(SchemaEntity):
EMPTY = False
SIMPLE = False
FIELD_NAMES = []
ENUMERATION = None
@ -193,9 +197,13 @@ class Definition(SchemaEntity):
def __init__(self, *args):
self._fields = args
if self.SIMPLE:
if len(args) != 1:
raise Exception('%s needs exactly one argument' % (self._constructor_name(),))
self.value = args[0]
if self.EMPTY:
if len(args) != 0:
raise Exception('%s takes no arguments' % (self._constructor_name(),))
else:
if len(args) != 1:
raise Exception('%s needs exactly one argument' % (self._constructor_name(),))
self.value = args[0]
else:
if len(args) != len(self.FIELD_NAMES):
raise Exception('%s needs argument(s) %r' % (self._constructor_name(), self.FIELD_NAMES))
@ -225,6 +233,7 @@ class Definition(SchemaEntity):
cls.SCHEMA = schema
cls.MODULE_PATH = module_path
cls.NAME = name
cls.EMPTY = is_empty_pattern(schema)
cls.SIMPLE = is_simple_pattern(schema)
cls.FIELD_NAMES = []
cls.VARIANT = variant
@ -235,7 +244,11 @@ class Definition(SchemaEntity):
def try_decode(cls, v):
if cls.SIMPLE:
i = cls.parse(cls.SCHEMA, v, [])
if i is not None: return cls(i)
if i is not None:
if cls.EMPTY:
return cls()
else:
return cls(i)
else:
args = []
if cls.parse(cls.SCHEMA, v, args) is not None: return cls(*args)
@ -260,6 +273,9 @@ SIMPLE_PATTERN_KEYS = [ATOM, EMBEDDED, LIT, SEQOF, SETOF, DICTOF, REF]
def is_simple_pattern(p):
return p == ANY or (isinstance(p, Record) and p.key in SIMPLE_PATTERN_KEYS)
def is_empty_pattern(p):
return isinstance(p, Record) and p.key == LIT
def gather_defined_field_names(s, acc):
if is_simple_pattern(s):
pass