diff --git a/syndicate/patterns.py b/syndicate/patterns.py index a617a99..91807f9 100644 --- a/syndicate/patterns.py +++ b/syndicate/patterns.py @@ -14,6 +14,8 @@ CAPTURE = bind(_) class unquote: def __init__(self, pattern): self.pattern = pattern + def __escape_schema__(self): + return self uCAPTURE = unquote(CAPTURE) u_ = unquote(_) @@ -24,7 +26,7 @@ u_ = unquote(_) # # then these all produce the same pattern: # -# P.rec('Observe', P.quote(P.rec('run', P.lit('N'), P.unquote(P.CAPTURE), P.bind(P.unquote(P._)))), P._) +# P.rec('Observe', P.quote(P.rec('run', P.lit('N'), P.uCAPTURE, P.bind(P.u_))), P._) # # P.rec('Observe', P.quote(P.quote(Run('N', P.unquote(P.uCAPTURE), P.unquote(P.bind(P.u_))))), P._) # @@ -60,6 +62,22 @@ def lit(v): else: return P.Pattern.DLit(P.DLit(P.AnyAtom.decode(v))) +def unlit(p): + if not hasattr(p, 'VARIANT'): + p = P.Pattern.decode(p) + if p.VARIANT == P.Pattern.DLit.VARIANT: + return p.value.value.value + if p.VARIANT != P.Pattern.DCompound.VARIANT: + raise Exception('Pattern does not represent a literal value') + p = p.value + if p.VARIANT == P.DCompound.rec.VARIANT: + return Record(p.label, map(unlit, p.fields)) + if p.VARIANT == P.DCompound.arr.VARIANT: + return list(map(unlit, p.items)) + if p.VARIANT == P.DCompound.dict.VARIANT: + return _dict(map(lambda kv: (kv[0], unlit(kv[1])), p.entries.items())) + raise Exception('unreachable') + def rec(labelstr, *members): return _rec(Symbol(labelstr), *members)