From 10682883a4fff8e532479cf64883e92c3d5f9fdc Mon Sep 17 00:00:00 2001 From: Emery Hemingway Date: Tue, 28 Feb 2023 22:05:44 -0600 Subject: [PATCH] preserves_schema_nim: detect schema self references Do not add a schema module as an import to itself when a definition refers to a type with a qualified module name that is the same as the schema the definition occurs in. --- src/preserves/preserves_schema_nim.nim | 50 +++++++++++++------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/src/preserves/preserves_schema_nim.nim b/src/preserves/preserves_schema_nim.nim index 8c3bc5d..ace740e 100644 --- a/src/preserves/preserves_schema_nim.nim +++ b/src/preserves/preserves_schema_nim.nim @@ -716,56 +716,56 @@ proc generateProcs(result: var seq[PNode]; scm: Schema; name: string; pat: Patte proc generateProcs(result: var seq[PNode]; scm: Schema; name: string; def: Definition) = discard -proc collectRefImports(imports: PNode; pat: Pattern) +proc collectRefImports(loc: Location; imports: PNode; pat: Pattern) -proc collectRefImports(imports: PNode; sp: SimplePattern) = +proc collectRefImports(loc: Location; imports: PNode; sp: SimplePattern) = case sp.orKind of SimplePatternKind.dictof: imports.add ident"std/tables" of SimplePatternKind.Ref: - if sp.`ref`.module != @[]: + if sp.`ref`.module != @[] and sp.`ref`.module != loc.schemaPath: imports.add ident(string sp.ref.module[0]) else: discard -proc collectRefImports(imports: PNode; cp: CompoundPattern) = +proc collectRefImports(loc: Location; imports: PNode; cp: CompoundPattern) = case cp.orKind of CompoundPatternKind.`rec`: - collectRefImports(imports, cp.rec.label.pattern) - collectRefImports(imports, cp.rec.fields.pattern) + collectRefImports(loc, imports, cp.rec.label.pattern) + collectRefImports(loc, imports, cp.rec.fields.pattern) of CompoundPatternKind.`tuple`: - for p in cp.tuple.patterns: collectRefImports(imports, p.pattern) + for p in cp.tuple.patterns: collectRefImports(loc, imports, p.pattern) of CompoundPatternKind.`tupleprefix`: - for np in cp.tupleprefix.fixed: collectRefImports(imports, np.pattern) - collectRefImports(imports, cp.tupleprefix.variable.pattern) + for np in cp.tupleprefix.fixed: collectRefImports(loc, imports, np.pattern) + collectRefImports(loc, imports, cp.tupleprefix.variable.pattern) of CompoundPatternKind.`dict`: for nsp in cp.dict.entries.values: - collectRefImports(imports, nsp.pattern) + collectRefImports(loc, imports, nsp.pattern) -proc collectRefImports(imports: PNode; pat: Pattern) = +proc collectRefImports(loc: Location; imports: PNode; pat: Pattern) = case pat.orKind of PatternKind.SimplePattern: - collectRefImports(imports, pat.simplePattern) + collectRefImports(loc, imports, pat.simplePattern) of PatternKind.CompoundPattern: - collectRefImports(imports, pat.compoundPattern) + collectRefImports(loc, imports, pat.compoundPattern) -proc collectRefImports(imports: PNode; def: Definition) = +proc collectRefImports(loc: Location; imports: PNode; def: Definition) = case def.orKind of DefinitionKind.`or`: - collectRefImports(imports, def.or.data.pattern0.pattern) - collectRefImports(imports, def.or.data.pattern1.pattern) + collectRefImports(loc, imports, def.or.data.pattern0.pattern) + collectRefImports(loc, imports, def.or.data.pattern1.pattern) for na in def.or.data.patternN: - collectRefImports(imports, na.pattern) + collectRefImports(loc, imports, na.pattern) of DefinitionKind.`and`: - collectRefImports(imports, def.and.data.pattern0.pattern) - collectRefImports(imports, def.and.data.pattern1.pattern) + collectRefImports(loc, imports, def.and.data.pattern0.pattern) + collectRefImports(loc, imports, def.and.data.pattern1.pattern) for np in def.and.data.patternN: - collectRefImports(imports, np.pattern) + collectRefImports(loc, imports, np.pattern) of DefinitionKind.Pattern: - collectRefImports(imports, def.pattern) + collectRefImports(loc, imports, def.pattern) -proc collectRefImports(imports: PNode; scm: Schema) = +proc collectRefImports(loc: Location; imports: PNode; scm: Schema) = for _, def in scm.data.definitions: - collectRefImports(imports, def) + collectRefImports(loc, imports, def) proc mergeType(x: var PNode; y: PNode) = if x.isNil: x = y @@ -781,12 +781,12 @@ proc renderNimBundle*(bundle: Bundle): Table[string, string] = result = initTable[string, string](bundle.modules.len) var typeDefs: TypeTable for scmPath, scm in bundle.modules: + let loc = (bundle, scmPath) var typeSection = newNode nkTypeSection procs: seq[PNode] unembeddableType, embeddableType: PNode for name, def in scm.data.definitions.pairs: - let loc = (bundle, scmPath) if isLiteral(loc, def): generateConstProcs(procs, scm, string name, def) else: @@ -809,7 +809,7 @@ proc renderNimBundle*(bundle: Bundle): Table[string, string] = var imports = nkImportStmt.newNode.add( ident"std/typetraits", ident"preserves") - collectRefImports(imports, scm) + collectRefImports(loc, imports, scm) if not embeddableType.isNil: let genericParams = nn(nkGenericParams, nn(nkIdentDefs, embeddedIdent(scm), newEmpty(), newEmpty()))