Update for new dataspace pattern language

This commit is contained in:
Tony Garnock-Jones 2024-04-09 15:06:08 +02:00
parent 6de5e96aa1
commit d9e1be2e98
1 changed files with 36 additions and 16 deletions

View File

@ -4,10 +4,10 @@ from preserves import preserve
_dict = dict ## we're about to shadow the builtin
_ = P.Pattern.DDiscard(P.DDiscard())
_ = P.Pattern.discard()
def bind(p):
return P.Pattern.DBind(P.DBind(p))
return P.Pattern.bind(p)
CAPTURE = bind(_)
@ -48,7 +48,7 @@ def quote(p):
elif isinstance(p, Record):
return _rec(p.key, *map(quote, p.fields))
else:
return P.Pattern.DLit(P.DLit(P.AnyAtom.decode(p)))
return P.Pattern.lit(P.AnyAtom.decode(p))
def lit(v):
if isinstance(v, list) or isinstance(v, tuple):
@ -60,21 +60,41 @@ def lit(v):
elif isinstance(v, Record):
return _rec(v.key, *map(lit, v.fields))
else:
return P.Pattern.DLit(P.DLit(P.AnyAtom.decode(v)))
return P.Pattern.lit(P.AnyAtom.decode(v))
def seq_entries(seq):
entries = {}
for i, p in enumerate(seq):
if p.VARIANT != P.Pattern.discard.VARIANT:
entries[i] = p
np = len(seq)
if np > 0 and (np - 1) not in entries:
entries[np - 1] = P.Pattern.discard()
return entries
def unlit_seq(entries):
seq = []
if len(entries) > 0:
try:
max_k = max(entries.keys())
except TypeError:
raise Exception('Pattern entries do not represent a gap-free sequence')
for i in range(max_k + 1):
seq.append(unlit(entries[i]))
return seq
def unlit(p):
if not hasattr(p, 'VARIANT'):
p = P.Pattern.decode(p)
if p.VARIANT == P.Pattern.DLit.VARIANT:
return p.value.value.value
if p.VARIANT != P.Pattern.DCompound.VARIANT:
if p.VARIANT == P.Pattern.lit.VARIANT:
return p.value.value
if p.VARIANT != P.Pattern.group.VARIANT:
raise Exception('Pattern does not represent a literal value')
p = p.value
if p.VARIANT == P.DCompound.rec.VARIANT:
return Record(p.label, map(unlit, p.fields))
if p.VARIANT == P.DCompound.arr.VARIANT:
return list(map(unlit, p.items))
if p.VARIANT == P.DCompound.dict.VARIANT:
if p.type.VARIANT == P.GroupType.rec.VARIANT:
return Record(p.type.label, unlit_seq(p.entries))
if p.type.VARIANT == P.GroupType.arr.VARIANT:
return list(unlit_seq(p.entries))
if p.type.VARIANT == P.GroupType.dict.VARIANT:
return _dict(map(lambda kv: (kv[0], unlit(kv[1])), p.entries.items()))
raise Exception('unreachable')
@ -82,10 +102,10 @@ def rec(labelstr, *members):
return _rec(Symbol(labelstr), *members)
def _rec(label, *members):
return P.Pattern.DCompound(P.DCompound.rec(label, members))
return P.Pattern.group(P.GroupType.rec(label), seq_entries(members))
def arr(*members):
return P.Pattern.DCompound(P.DCompound.arr(members))
return P.Pattern.group(P.GroupType.arr(), seq_entries(members))
def dict(*kvs):
return P.Pattern.DCompound(P.DCompound.dict(_dict(kvs)))
return P.Pattern.group(P.GroupType.dict(), _dict(kvs))