preserves_schema_nim: better recursive type detection

This commit is contained in:
Emery Hemingway 2023-12-29 13:47:25 +02:00
parent c01e587e5b
commit 97ab7ce070
3 changed files with 117 additions and 45 deletions

View File

@ -20,10 +20,7 @@ import ../preserves, ./schema
type type
Attribute = enum Attribute = enum
embedded embedded
## type contains an embedded value and ## type contains an embedded value
## must take an parameter
recursive
## type is recursive and therefore must be a ref
Attributes = set[Attribute] Attributes = set[Attribute]
TypeSpec = object TypeSpec = object
node: PNode node: PNode
@ -119,12 +116,15 @@ proc ident(`ref`: Ref): PNode =
dotExtend(result, `ref`.name.string.capitalizeAscii) dotExtend(result, `ref`.name.string.capitalizeAscii)
proc deref(loc: Location; r: Ref): (Location, Definition) = proc deref(loc: Location; r: Ref): (Location, Definition) =
result[0] = loc try:
if r.module == @[]: result[0] = loc
result[1] = loc.bundle.modules[loc.schemaPath].field0.definitions[r.name] if r.module == @[]:
else: result[1] = loc.bundle.modules[loc.schemaPath].field0.definitions[r.name]
result[0].schemaPath = r.module else:
result[1] = loc.bundle.modules[r.module].field0.definitions[r.name] result[0].schemaPath = r.module
result[1] = loc.bundle.modules[r.module].field0.definitions[r.name]
except KeyError:
raise newException(KeyError, "reference not found in bundle: " & $r)
proc hasEmbeddedType(scm: Schema): bool = proc hasEmbeddedType(scm: Schema): bool =
case scm.field0.embeddedType.orKind case scm.field0.embeddedType.orKind
@ -152,8 +152,7 @@ proc attrs(loc: Location; sp: SimplePattern; seen: RefSet): Attributes =
of SimplepatternKind.dictof: of SimplepatternKind.dictof:
attrs(loc, sp.dictof.key, seen) + attrs(loc, sp.dictof.value, seen) attrs(loc, sp.dictof.key, seen) + attrs(loc, sp.dictof.value, seen)
of SimplepatternKind.Ref: of SimplepatternKind.Ref:
if sp.ref in seen: {recursive} if (sp.ref in seen) or sp.ref.isAtomic: {}
elif sp.ref.isAtomic: {}
else: else:
var var
(loc, def) = deref(loc, sp.ref) (loc, def) = deref(loc, sp.ref)
@ -211,8 +210,85 @@ proc attrs(loc: Location; p: Definition|DefinitionOr|DefinitionAnd|Pattern|Compo
proc isEmbedded(loc: Location; p: Definition|DefinitionOr|DefinitionAnd|Pattern|CompoundPattern|SimplePattern): bool = proc isEmbedded(loc: Location; p: Definition|DefinitionOr|DefinitionAnd|Pattern|CompoundPattern|SimplePattern): bool =
embedded in attrs(loc, p) embedded in attrs(loc, p)
proc isRecursive(loc: Location; p: Definition|DefinitionOr|DefinitionAnd|Pattern|CompoundPattern): bool = proc isRecursive(loc: Location; name: string; pat: Pattern; seen: RefSet): bool {.gcsafe.}
recursive in attrs(loc, p)
proc isRecursive(loc: Location; name: string; def: Definition; seen: RefSet): bool {.gcsafe.}
proc isRecursive(loc: Location; name: string; n: NamedAlternative|NamedPattern; seen: RefSet): bool =
isRecursive(loc, name, n.pattern, seen)
proc isRecursive(loc: Location; name: string; sp: SimplePattern; seen: RefSet): bool =
case sp.orKind
of SimplepatternKind.embedded:
isRecursive(loc, name, sp.embedded.interface, seen)
of SimplepatternKind.Ref:
if sp.ref.name.string == name: true
elif sp.ref in seen: false
else:
var
(loc, def) = deref(loc, sp.ref)
seen = seen
incl(seen, sp.ref)
isRecursive(loc, name, def, seen)
else:
false
# seqof, setof, and dictof are not processed
# because they imply pointer indirection
proc isRecursive(loc: Location; name: string; np: NamedSimplePattern; seen: RefSet): bool =
case np.orKind
of NamedSimplePatternKind.named:
isRecursive(loc, name, np.named.pattern, seen)
of NamedSimplePatternKind.anonymous:
isRecursive(loc, name, np.anonymous, seen)
proc isRecursive(loc: Location; name: string; cp: CompoundPattern; seen: RefSet): bool =
case cp.orKind
of CompoundPatternKind.rec:
result =
isRecursive(loc, name, cp.rec.label.pattern, seen) or
isRecursive(loc, name, cp.rec.fields.pattern, seen)
of CompoundPatternKind.tuple:
for np in cp.tuple.patterns:
if result: return
result = isRecursive(loc, name, np.pattern, seen)
of CompoundPatternKind.tupleprefix:
result = isRecursive(loc, name, cp.tupleprefix.variable, seen)
for p in cp.tupleprefix.fixed:
if result: return
result = isRecursive(loc, name, p, seen)
of CompoundPatternKind.dict:
for nsp in cp.dict.entries.values:
if result: return
result = isRecursive(loc, name, nsp, seen)
proc isRecursive(loc: Location; name: string; pat: Pattern; seen: RefSet): bool =
case pat.orKind
of PatternKind.SimplePattern:
isRecursive(loc, name, pat.simplePattern, seen)
of PatternKind.CompoundPattern:
isRecursive(loc, name, pat.compoundPattern, seen)
proc isRecursive(loc: Location; name: string; def: DefinitionOr|DefinitionAnd; seen: RefSet): bool =
result =
isRecursive(loc, name, def.field0.pattern0, seen) or
isRecursive(loc, name, def.field0.pattern1, seen)
for p in def.field0.patternN:
if result: return
result = isRecursive(loc, name, p, seen)
proc isRecursive(loc: Location; name: string; def: Definition; seen: RefSet): bool =
case def.orKind
of DefinitionKind.or:
isRecursive(loc, name, def.or, seen)
of DefinitionKind.and:
isRecursive(loc, name, def.and, seen)
of DefinitionKind.Pattern:
isRecursive(loc, name, def.pattern, seen)
proc isRecursive(loc: Location; name: string; def: Definition): bool =
var seen: RefSet
isRecursive(loc, name, def, seen)
proc isLiteral(loc: Location; def: Definition): bool {.gcsafe.} proc isLiteral(loc: Location; def: Definition): bool {.gcsafe.}
proc isLiteral(loc: Location; pat: Pattern): bool {.gcsafe.} proc isLiteral(loc: Location; pat: Pattern): bool {.gcsafe.}
@ -466,7 +542,6 @@ proc typeDef(loc: Location; name: string; pat: Pattern; ty: PNode): PNode =
case pat.orKind case pat.orKind
of PatternKind.CompoundPattern: of PatternKind.CompoundPattern:
let pragma = newNode(nkPragma) let pragma = newNode(nkPragma)
if isRecursive(loc, pat): pragma.add(ident"acyclic")
case pat.compoundPattern.orKind case pat.compoundPattern.orKind
of CompoundPatternKind.rec: of CompoundPatternKind.rec:
if isLiteral(loc, pat.compoundPattern.rec.label): if isLiteral(loc, pat.compoundPattern.rec.label):
@ -491,8 +566,12 @@ proc typeDef(loc: Location; name: string; pat: Pattern; ty: PNode): PNode =
proc typeDef(loc: Location; name: string; def: Definition; ty: PNode): PNode = proc typeDef(loc: Location; name: string; def: Definition; ty: PNode): PNode =
case def.orKind case def.orKind
of DefinitionKind.or: of DefinitionKind.or:
var ty = ty
let pragma = newNode(nkPragma) let pragma = newNode(nkPragma)
if isRecursive(loc, def): pragma.add(ident"acyclic") if isRecursive(loc, name, def):
doAssert ty.kind == nkObjectTy
pragma.add(ident"acyclic")
ty = nkRefTy.newTree(ty)
pragma.add(ident"preservesOr") pragma.add(ident"preservesOr")
if isSymbolEnum(loc, def): if isSymbolEnum(loc, def):
pragma.add ident"pure" pragma.add ident"pure"
@ -656,8 +735,6 @@ proc nimTypeOf(loc: Location; known: var TypeTable; name: string; cp: CompoundPa
of CompoundPatternKind.`dict`: of CompoundPatternKind.`dict`:
result.node = nkObjectTy.newTree(newEmpty(), newEmpty(), result.node = nkObjectTy.newTree(newEmpty(), newEmpty(),
newNode(nkRecList).addFields(loc, known, name, cp.dict.entries)) newNode(nkRecList).addFields(loc, known, name, cp.dict.entries))
if result.node.kind == nkObjectTy and isRecursive(loc, cp):
result.node = nkRefTy.newTree(result.node)
proc nimTypeOf(loc: Location; known: var TypeTable; name: string; pat: Pattern): TypeSpec = proc nimTypeOf(loc: Location; known: var TypeTable; name: string; pat: Pattern): TypeSpec =
case pat.orKind case pat.orKind
@ -727,9 +804,6 @@ proc nimTypeOf(loc: Location; known: var TypeTable; name: string; orDef: Definit
newEmpty(), newEmpty(),
newEmpty(), newEmpty(),
nkRecList.newTree(recCase)) nkRecList.newTree(recCase))
# result.attrs = attrs(loc, orDef)
if result.node.kind == nkObjectTy and (recursive in attrs(loc, orDef)):
result.node = nkRefTy.newTree(result.node)
proc nimTypeOf(loc: Location; known: var TypeTable; name: string; def: DefinitionAnd): TypeSpec = proc nimTypeOf(loc: Location; known: var TypeTable; name: string; def: DefinitionAnd): TypeSpec =
if isDictionary(loc, def): if isDictionary(loc, def):
@ -953,7 +1027,6 @@ when isMainModule:
for inputPath in inputs: for inputPath in inputs:
var bundle: Bundle var bundle: Bundle
if dirExists inputPath: if dirExists inputPath:
new bundle
for filePath in walkDirRec(inputPath, relative = true): for filePath in walkDirRec(inputPath, relative = true):
var (dirPath, fileName, fileExt) = splitFile(filePath) var (dirPath, fileName, fileExt) = splitFile(filePath)
if fileExt == ".prs": if fileExt == ".prs":
@ -974,10 +1047,9 @@ when isMainModule:
if fromPreserves(schema, pr): if fromPreserves(schema, pr):
bundle.modules[@[Symbol fileName]] = schema bundle.modules[@[Symbol fileName]] = schema
else: else:
new bundle
var scm = parsePreservesSchema(readFile(inputPath), dirPath) var scm = parsePreservesSchema(readFile(inputPath), dirPath)
bundle.modules[@[Symbol fileName]] = scm bundle.modules[@[Symbol fileName]] = scm
if bundle.isNil or bundle.modules.len == 0: if bundle.modules.len == 0:
quit "Failed to recognize " & inputPath quit "Failed to recognize " & inputPath
else: else:
writeModules(bundle) writeModules(bundle)

View File

@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: ☭ Emery Hemingway # SPDX-FileCopyrightText: ☭ Emery Hemingway
# SPDX-License-Identifier: Unlicense # SPDX-License-Identifier: Unlicense
import std/[hashes, options, os, parseopt, streams, strutils, tables] import std/[hashes, os, parseopt, streams, strutils, tables]
import ../preserves, ./schema, ./schemaparse import ../preserves, ./schema, ./schemaparse
@ -36,7 +36,7 @@ when isMainModule:
write(outStream, schema.toPreserves) write(outStream, schema.toPreserves)
else: else:
let bundle = Bundle() var bundle: Bundle
if not dirExists inputPath: if not dirExists inputPath:
quit "not a directory of schemas: " & inputPath quit "not a directory of schemas: " & inputPath
else: else:

View File

@ -8,23 +8,23 @@ type
`name`*: Symbol `name`*: Symbol
ModulePath* = seq[Symbol] ModulePath* = seq[Symbol]
Bundle* {.acyclic, preservesRecord: "bundle".} = ref object Bundle* {.preservesRecord: "bundle".} = object
`modules`*: Modules `modules`*: Modules
CompoundPatternKind* {.pure.} = enum CompoundPatternKind* {.pure.} = enum
`rec`, `tuple`, `tuplePrefix`, `dict` `rec`, `tuple`, `tuplePrefix`, `dict`
CompoundPatternRec* {.acyclic, preservesRecord: "rec".} = ref object CompoundPatternRec* {.preservesRecord: "rec".} = object
`label`*: NamedPattern `label`*: NamedPattern
`fields`*: NamedPattern `fields`*: NamedPattern
CompoundPatternTuple* {.acyclic, preservesRecord: "tuple".} = ref object CompoundPatternTuple* {.preservesRecord: "tuple".} = object
`patterns`*: seq[NamedPattern] `patterns`*: seq[NamedPattern]
CompoundPatternTuplePrefix* {.acyclic, preservesRecord: "tuplePrefix".} = ref object CompoundPatternTuplePrefix* {.preservesRecord: "tuplePrefix".} = object
`fixed`*: seq[NamedPattern] `fixed`*: seq[NamedPattern]
`variable`*: NamedSimplePattern `variable`*: NamedSimplePattern
CompoundPatternDict* {.acyclic, preservesRecord: "dict".} = ref object CompoundPatternDict* {.preservesRecord: "dict".} = object
`entries`*: DictionaryEntries `entries`*: DictionaryEntries
`CompoundPattern`* {.acyclic, preservesOr.} = ref object `CompoundPattern`* {.acyclic, preservesOr.} = ref object
@ -75,19 +75,19 @@ type
SimplePatternAtom* {.preservesRecord: "atom".} = object SimplePatternAtom* {.preservesRecord: "atom".} = object
`atomKind`*: AtomKind `atomKind`*: AtomKind
SimplePatternEmbedded* {.acyclic, preservesRecord: "embedded".} = ref object SimplePatternEmbedded* {.preservesRecord: "embedded".} = object
`interface`*: SimplePattern `interface`*: SimplePattern
SimplePatternLit* {.preservesRecord: "lit".} = object SimplePatternLit* {.preservesRecord: "lit".} = object
`value`*: Value `value`*: Value
SimplePatternSeqof* {.acyclic, preservesRecord: "seqof".} = ref object SimplePatternSeqof* {.preservesRecord: "seqof".} = object
`pattern`*: SimplePattern `pattern`*: SimplePattern
SimplePatternSetof* {.acyclic, preservesRecord: "setof".} = ref object SimplePatternSetof* {.preservesRecord: "setof".} = object
`pattern`*: SimplePattern `pattern`*: SimplePattern
SimplePatternDictof* {.acyclic, preservesRecord: "dictof".} = ref object SimplePatternDictof* {.preservesRecord: "dictof".} = object
`key`*: SimplePattern `key`*: SimplePattern
`value`*: SimplePattern `value`*: SimplePattern
@ -120,7 +120,7 @@ type
NamedSimplePatternKind* {.pure.} = enum NamedSimplePatternKind* {.pure.} = enum
`named`, `anonymous` `named`, `anonymous`
`NamedSimplePattern`* {.acyclic, preservesOr.} = ref object `NamedSimplePattern`* {.preservesOr.} = object
case orKind*: NamedSimplePatternKind case orKind*: NamedSimplePatternKind
of NamedSimplePatternKind.`named`: of NamedSimplePatternKind.`named`:
`named`*: Binding `named`*: Binding
@ -131,23 +131,23 @@ type
DefinitionKind* {.pure.} = enum DefinitionKind* {.pure.} = enum
`or`, `and`, `Pattern` `or`, `and`, `Pattern`
DefinitionOrField0* {.acyclic, preservesTuple.} = ref object DefinitionOrField0* {.preservesTuple.} = object
`pattern0`*: NamedAlternative `pattern0`*: NamedAlternative
`pattern1`*: NamedAlternative `pattern1`*: NamedAlternative
`patternN`* {.preservesTupleTail.}: seq[NamedAlternative] `patternN`* {.preservesTupleTail.}: seq[NamedAlternative]
DefinitionOr* {.acyclic, preservesRecord: "or".} = ref object DefinitionOr* {.preservesRecord: "or".} = object
`field0`*: DefinitionOrField0 `field0`*: DefinitionOrField0
DefinitionAndField0* {.acyclic, preservesTuple.} = ref object DefinitionAndField0* {.preservesTuple.} = object
`pattern0`*: NamedPattern `pattern0`*: NamedPattern
`pattern1`*: NamedPattern `pattern1`*: NamedPattern
`patternN`* {.preservesTupleTail.}: seq[NamedPattern] `patternN`* {.preservesTupleTail.}: seq[NamedPattern]
DefinitionAnd* {.acyclic, preservesRecord: "and".} = ref object DefinitionAnd* {.preservesRecord: "and".} = object
`field0`*: DefinitionAndField0 `field0`*: DefinitionAndField0
`Definition`* {.acyclic, preservesOr.} = ref object `Definition`* {.preservesOr.} = object
case orKind*: DefinitionKind case orKind*: DefinitionKind
of DefinitionKind.`or`: of DefinitionKind.`or`:
`or`*: DefinitionOr `or`*: DefinitionOr
@ -159,16 +159,16 @@ type
`pattern`*: Pattern `pattern`*: Pattern
NamedAlternative* {.acyclic, preservesTuple.} = ref object NamedAlternative* {.preservesTuple.} = object
`variantLabel`*: string `variantLabel`*: string
`pattern`*: Pattern `pattern`*: Pattern
SchemaField0* {.acyclic, preservesDictionary.} = ref object SchemaField0* {.preservesDictionary.} = object
`definitions`*: Definitions `definitions`*: Definitions
`embeddedType`*: EmbeddedTypeName `embeddedType`*: EmbeddedTypeName
`version`* {.preservesLiteral: "1".}: tuple[] `version`* {.preservesLiteral: "1".}: tuple[]
Schema* {.acyclic, preservesRecord: "schema".} = ref object Schema* {.preservesRecord: "schema".} = object
`field0`*: SchemaField0 `field0`*: SchemaField0
PatternKind* {.pure.} = enum PatternKind* {.pure.} = enum
@ -182,7 +182,7 @@ type
`compoundpattern`*: CompoundPattern `compoundpattern`*: CompoundPattern
Binding* {.acyclic, preservesRecord: "named".} = ref object Binding* {.preservesRecord: "named".} = object
`name`*: Symbol `name`*: Symbol
`pattern`*: SimplePattern `pattern`*: SimplePattern