Fix positional initializers for schema values, and allow keyword initializers

This commit is contained in:
Tony Garnock-Jones 2022-06-07 21:45:56 +02:00
parent c9c973ce9c
commit e019148065
1 changed files with 26 additions and 11 deletions

View File

@ -128,7 +128,7 @@ class SchemaObject:
if p.key == DICT: if p.key == DICT:
if not isinstance(v, dict): return None if not isinstance(v, dict): return None
if len(v) < len(p[0]): return None if len(v) < len(p[0]): return None
for (k, pp) in p[0].items(): for (k, pp) in compare.sorted_items(p[0]):
if k not in v: return None if k not in v: return None
if cls.parse(pp, v[k], args) is None: return None if cls.parse(pp, v[k], args) is None: return None
return () return ()
@ -194,10 +194,14 @@ def safesetattr(o, k, v):
def safegetattr(o, k): def safegetattr(o, k):
return getattr(o, safeattrname(k)) return getattr(o, safeattrname(k))
def safehasattr(o, k):
return hasattr(o, safeattrname(k))
class Definition(SchemaObject): class Definition(SchemaObject):
EMPTY = False EMPTY = False
SIMPLE = False SIMPLE = False
FIELD_NAMES = [] FIELD_NAMES = []
SAFE_FIELD_NAMES = []
ENUMERATION = None ENUMERATION = None
def _constructor_name(self): def _constructor_name(self):
@ -206,7 +210,7 @@ class Definition(SchemaObject):
else: else:
return self.NAME.name + '.' + self.VARIANT.name return self.NAME.name + '.' + self.VARIANT.name
def __init__(self, *args): def __init__(self, *args, **kwargs):
self._fields = args self._fields = args
if self.SIMPLE: if self.SIMPLE:
if self.EMPTY: if self.EMPTY:
@ -217,12 +221,21 @@ class Definition(SchemaObject):
raise TypeError('%s needs exactly one argument' % (self._constructor_name(),)) raise TypeError('%s needs exactly one argument' % (self._constructor_name(),))
self.value = args[0] self.value = args[0]
else: else:
if len(args) != len(self.FIELD_NAMES):
raise TypeError('%s needs argument(s) %r' % (self._constructor_name(), self.FIELD_NAMES))
i = 0 i = 0
for k in self.FIELD_NAMES: for arg in args:
safesetattr(self, k, args[i]) if i >= len(self.FIELD_NAMES):
raise TypeError('%s given too many positional arguments' % (self._constructor_name(),))
setattr(self, self.SAFE_FIELD_NAMES[i], arg)
i = i + 1 i = i + 1
for (argname, arg) in kwargs.items():
if hasattr(self, argname):
raise TypeError('%s given duplicate attribute: %r' % (self._constructor_name, argname))
if argname not in self.SAFE_FIELD_NAMES:
raise TypeError('%s given unknown attribute: %r' % (self._constructor_name, argname))
setattr(self, argname, arg)
i = i + 1
if i != len(self.FIELD_NAMES):
raise TypeError('%s needs argument(s) %r' % (self._constructor_name(), self.FIELD_NAMES))
def __eq__(self, other): def __eq__(self, other):
return (other.__class__ is self.__class__) and (self._fields == other._fields) return (other.__class__ is self.__class__) and (self._fields == other._fields)
@ -251,6 +264,7 @@ class Definition(SchemaObject):
cls.VARIANT = variant cls.VARIANT = variant
cls.ENUMERATION = enumeration cls.ENUMERATION = enumeration
gather_defined_field_names(schema, cls.FIELD_NAMES) gather_defined_field_names(schema, cls.FIELD_NAMES)
cls.SAFE_FIELD_NAMES = [safeattrname(n) for n in cls.FIELD_NAMES]
@classmethod @classmethod
def try_decode(cls, v): def try_decode(cls, v):
@ -263,7 +277,8 @@ class Definition(SchemaObject):
return cls(i) return cls(i)
else: else:
args = [] args = []
if cls.parse(cls.SCHEMA, v, args) is not None: return cls(*args) if cls.parse(cls.SCHEMA, v, args) is not None:
return cls(*args)
return None return None
def __preserve__(self): def __preserve__(self):
@ -276,10 +291,10 @@ class Definition(SchemaObject):
return encode(self.SCHEMA, self) return encode(self.SCHEMA, self)
def _as_dict(self): def _as_dict(self):
return dict((k, getattr(self, k)) for k in self.FIELD_NAMES) return dict((k, safegetattr(self, k)) for k in self.FIELD_NAMES)
def __getitem__(self, name): def __getitem__(self, name):
return getattr(self, name) return safegetattr(self, name)
def __setitem__(self, name, value): def __setitem__(self, name, value):
return safesetattr(self, name, value) return safesetattr(self, name, value)
@ -353,7 +368,7 @@ def gather_defined_field_names(s, acc):
gather_defined_field_names(s[0], acc) gather_defined_field_names(s[0], acc)
gather_defined_field_names(s[1], acc) gather_defined_field_names(s[1], acc)
elif s.key == DICT: elif s.key == DICT:
gather_defined_field_names(tuple(s[0].values()), acc) gather_defined_field_names(tuple(item[1] for item in compare.sorted_items(s[0])), acc)
else: else:
raise ValueError('Bad schema') raise ValueError('Bad schema')
@ -390,7 +405,7 @@ class Namespace:
safesetattr(self, name, value) safesetattr(self, name, value)
def __contains__(self, name): def __contains__(self, name):
return Symbol(name).name in self.__dict__ return safeattrname(Symbol(name).name) in self.__dict__
def _items(self): def _items(self):
return dict((k, v) for (k, v) in self.__dict__.items() if k[0] != '_') return dict((k, v) for (k, v) in self.__dict__.items() if k[0] != '_')