From e01914806532370741ed2404f37058c8b0361ab8 Mon Sep 17 00:00:00 2001 From: Tony Garnock-Jones Date: Tue, 7 Jun 2022 21:45:56 +0200 Subject: [PATCH] Fix positional initializers for schema values, and allow keyword initializers --- implementations/python/preserves/schema.py | 37 +++++++++++++++------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/implementations/python/preserves/schema.py b/implementations/python/preserves/schema.py index a94f82c..771b84c 100644 --- a/implementations/python/preserves/schema.py +++ b/implementations/python/preserves/schema.py @@ -128,7 +128,7 @@ class SchemaObject: if p.key == DICT: if not isinstance(v, dict): return None if len(v) < len(p[0]): return None - for (k, pp) in p[0].items(): + for (k, pp) in compare.sorted_items(p[0]): if k not in v: return None if cls.parse(pp, v[k], args) is None: return None return () @@ -194,10 +194,14 @@ def safesetattr(o, k, v): def safegetattr(o, k): return getattr(o, safeattrname(k)) +def safehasattr(o, k): + return hasattr(o, safeattrname(k)) + class Definition(SchemaObject): EMPTY = False SIMPLE = False FIELD_NAMES = [] + SAFE_FIELD_NAMES = [] ENUMERATION = None def _constructor_name(self): @@ -206,7 +210,7 @@ class Definition(SchemaObject): else: return self.NAME.name + '.' + self.VARIANT.name - def __init__(self, *args): + def __init__(self, *args, **kwargs): self._fields = args if self.SIMPLE: if self.EMPTY: @@ -217,12 +221,21 @@ class Definition(SchemaObject): raise TypeError('%s needs exactly one argument' % (self._constructor_name(),)) self.value = args[0] else: - if len(args) != len(self.FIELD_NAMES): - raise TypeError('%s needs argument(s) %r' % (self._constructor_name(), self.FIELD_NAMES)) i = 0 - for k in self.FIELD_NAMES: - safesetattr(self, k, args[i]) + for arg in args: + if i >= len(self.FIELD_NAMES): + raise TypeError('%s given too many positional arguments' % (self._constructor_name(),)) + setattr(self, self.SAFE_FIELD_NAMES[i], arg) i = i + 1 + for (argname, arg) in kwargs.items(): + if hasattr(self, argname): + raise TypeError('%s given duplicate attribute: %r' % (self._constructor_name, argname)) + if argname not in self.SAFE_FIELD_NAMES: + raise TypeError('%s given unknown attribute: %r' % (self._constructor_name, argname)) + setattr(self, argname, arg) + i = i + 1 + if i != len(self.FIELD_NAMES): + raise TypeError('%s needs argument(s) %r' % (self._constructor_name(), self.FIELD_NAMES)) def __eq__(self, other): return (other.__class__ is self.__class__) and (self._fields == other._fields) @@ -251,6 +264,7 @@ class Definition(SchemaObject): cls.VARIANT = variant cls.ENUMERATION = enumeration gather_defined_field_names(schema, cls.FIELD_NAMES) + cls.SAFE_FIELD_NAMES = [safeattrname(n) for n in cls.FIELD_NAMES] @classmethod def try_decode(cls, v): @@ -263,7 +277,8 @@ class Definition(SchemaObject): return cls(i) else: args = [] - if cls.parse(cls.SCHEMA, v, args) is not None: return cls(*args) + if cls.parse(cls.SCHEMA, v, args) is not None: + return cls(*args) return None def __preserve__(self): @@ -276,10 +291,10 @@ class Definition(SchemaObject): return encode(self.SCHEMA, self) def _as_dict(self): - return dict((k, getattr(self, k)) for k in self.FIELD_NAMES) + return dict((k, safegetattr(self, k)) for k in self.FIELD_NAMES) def __getitem__(self, name): - return getattr(self, name) + return safegetattr(self, name) def __setitem__(self, name, value): return safesetattr(self, name, value) @@ -353,7 +368,7 @@ def gather_defined_field_names(s, acc): gather_defined_field_names(s[0], acc) gather_defined_field_names(s[1], acc) elif s.key == DICT: - gather_defined_field_names(tuple(s[0].values()), acc) + gather_defined_field_names(tuple(item[1] for item in compare.sorted_items(s[0])), acc) else: raise ValueError('Bad schema') @@ -390,7 +405,7 @@ class Namespace: safesetattr(self, name, value) def __contains__(self, name): - return Symbol(name).name in self.__dict__ + return safeattrname(Symbol(name).name) in self.__dict__ def _items(self): return dict((k, v) for (k, v) in self.__dict__.items() if k[0] != '_')