diff --git a/syndicate/patterns.py b/syndicate/patterns.py index 91807f9..002431e 100644 --- a/syndicate/patterns.py +++ b/syndicate/patterns.py @@ -4,10 +4,10 @@ from preserves import preserve _dict = dict ## we're about to shadow the builtin -_ = P.Pattern.DDiscard(P.DDiscard()) +_ = P.Pattern.discard() def bind(p): - return P.Pattern.DBind(P.DBind(p)) + return P.Pattern.bind(p) CAPTURE = bind(_) @@ -48,7 +48,7 @@ def quote(p): elif isinstance(p, Record): return _rec(p.key, *map(quote, p.fields)) else: - return P.Pattern.DLit(P.DLit(P.AnyAtom.decode(p))) + return P.Pattern.lit(P.AnyAtom.decode(p)) def lit(v): if isinstance(v, list) or isinstance(v, tuple): @@ -60,21 +60,41 @@ def lit(v): elif isinstance(v, Record): return _rec(v.key, *map(lit, v.fields)) else: - return P.Pattern.DLit(P.DLit(P.AnyAtom.decode(v))) + return P.Pattern.lit(P.AnyAtom.decode(v)) + +def seq_entries(seq): + entries = {} + for i, p in enumerate(seq): + if p.VARIANT != P.Pattern.discard.VARIANT: + entries[i] = p + np = len(seq) + if np > 0 and (np - 1) not in entries: + entries[np - 1] = P.Pattern.discard() + return entries + +def unlit_seq(entries): + seq = [] + if len(entries) > 0: + try: + max_k = max(entries.keys()) + except TypeError: + raise Exception('Pattern entries do not represent a gap-free sequence') + for i in range(max_k + 1): + seq.append(unlit(entries[i])) + return seq def unlit(p): if not hasattr(p, 'VARIANT'): p = P.Pattern.decode(p) - if p.VARIANT == P.Pattern.DLit.VARIANT: - return p.value.value.value - if p.VARIANT != P.Pattern.DCompound.VARIANT: + if p.VARIANT == P.Pattern.lit.VARIANT: + return p.value.value + if p.VARIANT != P.Pattern.group.VARIANT: raise Exception('Pattern does not represent a literal value') - p = p.value - if p.VARIANT == P.DCompound.rec.VARIANT: - return Record(p.label, map(unlit, p.fields)) - if p.VARIANT == P.DCompound.arr.VARIANT: - return list(map(unlit, p.items)) - if p.VARIANT == P.DCompound.dict.VARIANT: + if p.type.VARIANT == P.GroupType.rec.VARIANT: + return Record(p.type.label, unlit_seq(p.entries)) + if p.type.VARIANT == P.GroupType.arr.VARIANT: + return list(unlit_seq(p.entries)) + if p.type.VARIANT == P.GroupType.dict.VARIANT: return _dict(map(lambda kv: (kv[0], unlit(kv[1])), p.entries.items())) raise Exception('unreachable') @@ -82,10 +102,10 @@ def rec(labelstr, *members): return _rec(Symbol(labelstr), *members) def _rec(label, *members): - return P.Pattern.DCompound(P.DCompound.rec(label, members)) + return P.Pattern.group(P.GroupType.rec(label), seq_entries(members)) def arr(*members): - return P.Pattern.DCompound(P.DCompound.arr(members)) + return P.Pattern.group(P.GroupType.arr(), seq_entries(members)) def dict(*kvs): - return P.Pattern.DCompound(P.DCompound.dict(_dict(kvs))) + return P.Pattern.group(P.GroupType.dict(), _dict(kvs))