preserves_schema_nim: detect schema self references

Do not add a schema module as an import to itself when a
definition refers to a type with a qualified module name that is
the same as the schema the definition occurs in.
This commit is contained in:
Emery Hemingway 2023-02-28 22:05:44 -06:00
parent 70655a959b
commit 10682883a4
1 changed files with 25 additions and 25 deletions

View File

@ -716,56 +716,56 @@ proc generateProcs(result: var seq[PNode]; scm: Schema; name: string; pat: Patte
proc generateProcs(result: var seq[PNode]; scm: Schema; name: string; def: Definition) =
discard
proc collectRefImports(imports: PNode; pat: Pattern)
proc collectRefImports(loc: Location; imports: PNode; pat: Pattern)
proc collectRefImports(imports: PNode; sp: SimplePattern) =
proc collectRefImports(loc: Location; imports: PNode; sp: SimplePattern) =
case sp.orKind
of SimplePatternKind.dictof:
imports.add ident"std/tables"
of SimplePatternKind.Ref:
if sp.`ref`.module != @[]:
if sp.`ref`.module != @[] and sp.`ref`.module != loc.schemaPath:
imports.add ident(string sp.ref.module[0])
else: discard
proc collectRefImports(imports: PNode; cp: CompoundPattern) =
proc collectRefImports(loc: Location; imports: PNode; cp: CompoundPattern) =
case cp.orKind
of CompoundPatternKind.`rec`:
collectRefImports(imports, cp.rec.label.pattern)
collectRefImports(imports, cp.rec.fields.pattern)
collectRefImports(loc, imports, cp.rec.label.pattern)
collectRefImports(loc, imports, cp.rec.fields.pattern)
of CompoundPatternKind.`tuple`:
for p in cp.tuple.patterns: collectRefImports(imports, p.pattern)
for p in cp.tuple.patterns: collectRefImports(loc, imports, p.pattern)
of CompoundPatternKind.`tupleprefix`:
for np in cp.tupleprefix.fixed: collectRefImports(imports, np.pattern)
collectRefImports(imports, cp.tupleprefix.variable.pattern)
for np in cp.tupleprefix.fixed: collectRefImports(loc, imports, np.pattern)
collectRefImports(loc, imports, cp.tupleprefix.variable.pattern)
of CompoundPatternKind.`dict`:
for nsp in cp.dict.entries.values:
collectRefImports(imports, nsp.pattern)
collectRefImports(loc, imports, nsp.pattern)
proc collectRefImports(imports: PNode; pat: Pattern) =
proc collectRefImports(loc: Location; imports: PNode; pat: Pattern) =
case pat.orKind
of PatternKind.SimplePattern:
collectRefImports(imports, pat.simplePattern)
collectRefImports(loc, imports, pat.simplePattern)
of PatternKind.CompoundPattern:
collectRefImports(imports, pat.compoundPattern)
collectRefImports(loc, imports, pat.compoundPattern)
proc collectRefImports(imports: PNode; def: Definition) =
proc collectRefImports(loc: Location; imports: PNode; def: Definition) =
case def.orKind
of DefinitionKind.`or`:
collectRefImports(imports, def.or.data.pattern0.pattern)
collectRefImports(imports, def.or.data.pattern1.pattern)
collectRefImports(loc, imports, def.or.data.pattern0.pattern)
collectRefImports(loc, imports, def.or.data.pattern1.pattern)
for na in def.or.data.patternN:
collectRefImports(imports, na.pattern)
collectRefImports(loc, imports, na.pattern)
of DefinitionKind.`and`:
collectRefImports(imports, def.and.data.pattern0.pattern)
collectRefImports(imports, def.and.data.pattern1.pattern)
collectRefImports(loc, imports, def.and.data.pattern0.pattern)
collectRefImports(loc, imports, def.and.data.pattern1.pattern)
for np in def.and.data.patternN:
collectRefImports(imports, np.pattern)
collectRefImports(loc, imports, np.pattern)
of DefinitionKind.Pattern:
collectRefImports(imports, def.pattern)
collectRefImports(loc, imports, def.pattern)
proc collectRefImports(imports: PNode; scm: Schema) =
proc collectRefImports(loc: Location; imports: PNode; scm: Schema) =
for _, def in scm.data.definitions:
collectRefImports(imports, def)
collectRefImports(loc, imports, def)
proc mergeType(x: var PNode; y: PNode) =
if x.isNil: x = y
@ -781,12 +781,12 @@ proc renderNimBundle*(bundle: Bundle): Table[string, string] =
result = initTable[string, string](bundle.modules.len)
var typeDefs: TypeTable
for scmPath, scm in bundle.modules:
let loc = (bundle, scmPath)
var
typeSection = newNode nkTypeSection
procs: seq[PNode]
unembeddableType, embeddableType: PNode
for name, def in scm.data.definitions.pairs:
let loc = (bundle, scmPath)
if isLiteral(loc, def):
generateConstProcs(procs, scm, string name, def)
else:
@ -809,7 +809,7 @@ proc renderNimBundle*(bundle: Bundle): Table[string, string] =
var imports = nkImportStmt.newNode.add(
ident"std/typetraits",
ident"preserves")
collectRefImports(imports, scm)
collectRefImports(loc, imports, scm)
if not embeddableType.isNil:
let genericParams =
nn(nkGenericParams, nn(nkIdentDefs, embeddedIdent(scm), newEmpty(), newEmpty()))