patterns: sort dictionary keys during analysis

This commit is contained in:
Emery Hemingway 2023-07-23 08:30:45 +01:00
parent 7b2d59e4cd
commit 16cc5aaf98
2 changed files with 28 additions and 3 deletions

View File

@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: ☭ Emery Hemingway
# SPDX-License-Identifier: Unlicense
import std/[options, sequtils, tables, typetraits]
import std/[algorithm, options, sequtils, tables, typetraits]
import preserves
import ./protocols/dataspacePatterns
@ -10,6 +10,7 @@ from ./actors import Ref
export dataspacePatterns.`$`, PatternKind, DCompoundKind, AnyAtomKind
type
Value = Preserve[Ref]
AnyAtom = dataspacePatterns.AnyAtom[Ref]
DBind = dataspacePatterns.DBind[Ref]
DCompound = dataspacePatterns.DCompound[Ref]
@ -19,6 +20,17 @@ type
DLit = dataspacePatterns.DLit[Ref]
Pattern* = dataspacePatterns.Pattern[Ref]
iterator orderedEntries*(dict: DCompoundDict): (Value, Pattern) =
## Iterate a `DCompoundDict` in Preserves order.
## Values captured from a dictionary are represented as an
## array of values ordered by their former key, so using an
## ordered iterator is sometimes essential.
var keys = dict.entries.keys.toSeq
sort(keys, preserves.cmp)
for k in keys:
yield(k, dict.entries.getOrDefault(k))
# getOrDefault doesn't raise and we know the keys will match
proc toPattern(d: sink DBind): Pattern =
Pattern(orKind: PatternKind.DBind, dbind: d)
@ -275,7 +287,6 @@ proc recordPattern*(label: Preserve[Ref], fields: varargs[Pattern]): Pattern =
DCompoundRec(label: label, fields: fields.toSeq).toPattern
type
Value = Preserve[Ref]
Path* = seq[Value]
Paths* = seq[Path]
Captures* = seq[Value]
@ -300,7 +311,7 @@ func walk(result: var Analysis; path: var Path; p: Pattern) =
of DCompoundKind.arr:
for k, e in p.dcompound.arr.items: walk(result, path, k, e)
of DCompoundKind.dict:
for k, e in p.dcompound.dict.entries: walk(result, path, k, e)
for k, e in p.dcompound.dict.orderedEntries: walk(result, path, k, e)
of PatternKind.DBind:
result.capturePaths.add(path)
walk(result, path, p.dbind.pattern)

View File

@ -20,3 +20,17 @@ test "patterns":
have = capture(observerPat, observer).toPreserve(Ref).unpackLiterals
want = [value.toPreserve(Ref)].toPreserve(Ref)
check(have == want)
type Record {.preservesDictionary.} = object
a, b, c: int
test "dictionaries":
let pat = ?Record
echo pat
var source = initDictionary(Ref)
source["b".toSymbol(Ref)] = 2.toPreserve(Ref)
source["c".toSymbol(Ref)] = 3.toPreserve(Ref)
source["a".toSymbol(Ref)] = 1.toPreserve(Ref)
let values = capture(pat, source)
check $values == "@[1, 2, 3]"