diff --git a/src/main/java/raylras/zen/code/resolve/SymbolResolver.java b/src/main/java/raylras/zen/code/resolve/SymbolResolver.java index d1b2d735..8a3de37b 100644 --- a/src/main/java/raylras/zen/code/resolve/SymbolResolver.java +++ b/src/main/java/raylras/zen/code/resolve/SymbolResolver.java @@ -4,52 +4,56 @@ import raylras.zen.code.CompilationUnit; import raylras.zen.code.SymbolProvider; import raylras.zen.code.Visitor; +import raylras.zen.code.parser.ZenScriptParser; import raylras.zen.code.parser.ZenScriptParser.MemberAccessExprContext; import raylras.zen.code.parser.ZenScriptParser.SimpleNameExprContext; import raylras.zen.code.parser.ZenScriptParser.StatementContext; import raylras.zen.code.scope.Scope; -import raylras.zen.code.symbol.ClassSymbol; -import raylras.zen.code.symbol.Symbol; -import raylras.zen.util.PackageTree; +import raylras.zen.code.symbol.*; +import raylras.zen.util.CSTNodes; import raylras.zen.util.Ranges; +import raylras.zen.util.Symbols; import java.util.Collection; import java.util.Collections; -import java.util.Objects; +import java.util.List; import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; public class SymbolResolver { + public static Collection lookupSymbol(ParseTree cst, CompilationUnit unit) { - ParseTree statement = findCurrentStatement(cst); - if (statement == null) { - return Collections.emptyList(); - } - SymbolVisitor visitor = new SymbolVisitor(unit, cst); - visitor.visit(statement); - return visitor.result; + return lookupSymbol(cst, unit, true); } - private static ParseTree findCurrentStatement(ParseTree cst) { - ParseTree current = cst; - while (current != null) { - if (current instanceof StatementContext) { - return current; - } else { - current = current.getParent(); - } + public static Collection lookupSymbol(ParseTree cst, CompilationUnit unit, boolean expandImport) { + ParseTree owner = CSTNodes.findParentOfTypes(cst, + StatementContext.class, + ZenScriptParser.QualifiedNameContext.class, + ZenScriptParser.InitializerContext.class, + ZenScriptParser.DefaultValueContext.class + ); + + if (owner == null) { + return Collections.emptyList(); } - return null; + SymbolVisitor visitor = new SymbolVisitor(unit, cst, expandImport); + visitor.visit(owner); + return visitor.result; } private static class SymbolVisitor extends Visitor { private final CompilationUnit unit; private final ParseTree cst; + private final boolean expandImport; private Collection result = Collections.emptyList(); - public SymbolVisitor(CompilationUnit unit, ParseTree cst) { + public SymbolVisitor(CompilationUnit unit, ParseTree cst, boolean expandImport) { this.unit = unit; this.cst = cst; + this.expandImport = expandImport; } @Override @@ -64,27 +68,54 @@ public SymbolProvider visitSimpleNameExpr(SimpleNameExprContext ctx) { @Override public SymbolProvider visitMemberAccessExpr(MemberAccessExprContext ctx) { SymbolProvider leftPossibles = visit(ctx.expression()); + SymbolProvider foundResults = findMembers(ctx.simpleName(), leftPossibles); + if (Ranges.contains(cst, ctx.simpleName())) { + result = foundResults.getSymbols(); + } + return findMembers(ctx.simpleName(), leftPossibles); + } + + private SymbolProvider findMembers(ZenScriptParser.SimpleNameContext simpleName, SymbolProvider leftPossibles) { if (leftPossibles.size() != 1) { return SymbolProvider.EMPTY; } - Symbol leftSymbol = leftPossibles.getFirst(); SymbolProvider foundResults; - if (leftSymbol instanceof ClassSymbol classSymbol) { + if (leftSymbol instanceof PackageSymbol packageSymbol) { + foundResults = packageSymbol + .filter(isSymbolNameEquals(simpleName)); + } else if (leftSymbol instanceof ClassSymbol classSymbol) { foundResults = classSymbol .filter(Symbol::isStatic) - .filter(isSymbolNameEquals(ctx.simpleName())); + .filter(isSymbolNameEquals(simpleName)); } else if (leftSymbol.getType() instanceof SymbolProvider type) { - foundResults = type.withExpands(unit.getEnv()).filter(isSymbolNameEquals(ctx.simpleName())); + foundResults = type.withExpands(unit.getEnv()).filter(isSymbolNameEquals(simpleName)); } else { foundResults = SymbolProvider.EMPTY; } - if (Ranges.contains(cst, ctx.simpleName())) { - result = foundResults.getSymbols(); - } return foundResults; } + @Override + public SymbolProvider visitQualifiedName(ZenScriptParser.QualifiedNameContext ctx) { + List simpleNames = ctx.simpleName(); + if (simpleNames.isEmpty()) { + return SymbolProvider.EMPTY; + } + + SymbolProvider possibles = lookupSymbol(ctx, simpleNames.get(0)); + if (Ranges.contains(cst, simpleNames.get(0))) { + result = possibles.getSymbols(); + } + for (int i = 1; i < simpleNames.size(); i++) { + possibles = findMembers(simpleNames.get(i), possibles); + if (Ranges.contains(cst, simpleNames.get(i))) { + result = possibles.getSymbols(); + } + } + return possibles; + } + @Override protected SymbolProvider defaultResult() { return SymbolProvider.EMPTY; @@ -96,22 +127,35 @@ private SymbolProvider lookupSymbol(ParseTree cst, ParseTree name) { private SymbolProvider lookupSymbol(ParseTree cst, String name) { Scope scope = unit.lookupScope(cst); - if (scope != null) { - return scope.filter(isSymbolNameEquals(name)) - .orElse(() -> lookupGlobalSymbols(name)); - } else { - return () -> lookupGlobalSymbols(name); + if(scope == null) { + return SymbolProvider.EMPTY; } + List result = scope.lookupSymbols(name); + + if (result.isEmpty()) { + return SymbolProvider.of(Symbols.lookupGlobalSymbols(unit, name)); + } + if (expandImport) { + result = result.stream().flatMap(it -> { + // try to expand import + Collection importTargets = tryExpandImport(it); + if (!importTargets.isEmpty()) { + return importTargets.stream(); + } + return Stream.of(it); + }).collect(Collectors.toList()); + } + return SymbolProvider.of(result); } - private Collection lookupGlobalSymbols(String name) { - Collection globals = unit.getEnv().getGlobalSymbols().stream().filter(it -> Objects.equals(it.getName(), name)).toList(); - if (globals.isEmpty()) { - // TODO: find package - PackageTree packageTree = PackageTree.of(".", unit.getEnv().getClassSymbolMap()); - return Collections.emptyList(); + private Collection tryExpandImport(Symbol symbol) { + if (symbol instanceof ImportSymbol && symbol instanceof ParseTreeLocatable locatable && locatable.getCst() instanceof ZenScriptParser.ImportDeclarationContext importCtx) { + Collection importTargets = this.visit(importCtx.qualifiedName()).getSymbols(); + if (!importTargets.isEmpty()) { + return importTargets; + } } - return globals; + return Collections.emptyList(); } private static Predicate isSymbolNameEquals(ParseTree name) { diff --git a/src/main/java/raylras/zen/code/symbol/ImportSymbol.java b/src/main/java/raylras/zen/code/symbol/ImportSymbol.java index 3a19226d..3a8a4996 100644 --- a/src/main/java/raylras/zen/code/symbol/ImportSymbol.java +++ b/src/main/java/raylras/zen/code/symbol/ImportSymbol.java @@ -1,11 +1,7 @@ package raylras.zen.code.symbol; -import java.util.List; - public interface ImportSymbol extends Symbol { String getQualifiedName(); - List getTargets(); - } diff --git a/src/main/java/raylras/zen/code/symbol/SymbolFactory.java b/src/main/java/raylras/zen/code/symbol/SymbolFactory.java index 355018cd..bd207c2e 100644 --- a/src/main/java/raylras/zen/code/symbol/SymbolFactory.java +++ b/src/main/java/raylras/zen/code/symbol/SymbolFactory.java @@ -1,17 +1,23 @@ package raylras.zen.code.symbol; import org.antlr.v4.runtime.tree.ParseTree; +import raylras.zen.code.CompilationEnvironment; import raylras.zen.code.CompilationUnit; import raylras.zen.code.parser.ZenScriptParser.*; import raylras.zen.code.resolve.FormalParameterResolver; import raylras.zen.code.resolve.ModifierResolver; import raylras.zen.code.resolve.TypeResolver; import raylras.zen.code.scope.Scope; -import raylras.zen.code.type.*; +import raylras.zen.code.type.AnyType; +import raylras.zen.code.type.ClassType; +import raylras.zen.code.type.FunctionType; +import raylras.zen.code.type.Type; import raylras.zen.util.CSTNodes; import raylras.zen.util.Operators; +import raylras.zen.util.PackageTree; import raylras.zen.util.Range; +import java.nio.file.Path; import java.util.*; import java.util.function.UnaryOperator; import java.util.stream.Collectors; @@ -31,13 +37,6 @@ class ImportSymbolImpl implements ImportSymbol, ParseTreeLocatable { public String getQualifiedName() { return cst.qualifiedName().getText(); } - - @Override - public List getTargets() { - // TODO: import static members and implement this - return Collections.emptyList(); - } - @Override public String getName() { return name; @@ -50,8 +49,7 @@ public Kind getKind() { @Override public Type getType() { - // TODO: import static members - return unit.getEnv().getClassTypeMap().get(getQualifiedName()); + return AnyType.INSTANCE; } @Override @@ -614,6 +612,94 @@ public Modifier getModifier() { return new ParameterSymbolImpl(); } + public static PackageSymbol createPackageSymbol(String name, String qualifiedName, PackageTree packageTree, CompilationEnvironment environment, boolean isGenerated) { + Path root; + if (isGenerated) { + root = environment.getGeneratedRoot(); + } else { + root = environment.getRoot(); + } + if (packageTree.hasElement()) { + String relativePath = qualifiedName.replace(".", "/"); + if (isGenerated) { + relativePath += ".dzs"; + } else { + relativePath += ".zs"; + } + root = root.resolve(relativePath).normalize(); + } else { + root = root.resolve(qualifiedName.replace(".", "/")).normalize(); + } + + final Path rootFinal = root; + class PackageSymbolImpl implements PackageSymbol, Locatable { + + @Override + public String getName() { + return name; + } + + @Override + public Kind getKind() { + return Kind.PACKAGE; + } + + @Override + public Type getType() { + return AnyType.INSTANCE; + } + + @Override + public Modifier getModifier() { + return Modifier.NONE; + } + + @Override + public Path getPath() { + return rootFinal; + } + + @Override + public String getUri() { + return getPath().toUri().toString(); + } + + @Override + public Range getRange() { + return Range.NO_RANGE; + } + + @Override + public Range getSelectionRange() { + return Range.NO_RANGE; + } + + @Override + public String getQualifiedName() { + return qualifiedName; + } + + @Override + public List getSymbols() { + List result = new ArrayList<>(); + for (Map.Entry> entry : packageTree.getSubTrees().entrySet()) { + PackageTree tree = entry.getValue(); + if (tree.hasElement()) { + result.add(tree.getElement()); + } else { + String name = entry.getKey(); + PackageSymbol packageSymbol = createPackageSymbol(name, qualifiedName + "." + name, entry.getValue(), environment, isGenerated); + result.add(packageSymbol); + } + } + return result; + } + } + + return new PackageSymbolImpl(); + } + + public static SymbolsBuilder builtinSymbols() { return new SymbolsBuilder(); } diff --git a/src/main/java/raylras/zen/langserver/provider/DefinitionProvider.java b/src/main/java/raylras/zen/langserver/provider/DefinitionProvider.java index ecc19aa1..68c3d696 100644 --- a/src/main/java/raylras/zen/langserver/provider/DefinitionProvider.java +++ b/src/main/java/raylras/zen/langserver/provider/DefinitionProvider.java @@ -8,7 +8,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import raylras.zen.code.resolve.SymbolResolver; -import raylras.zen.code.symbol.ParseTreeLocatable; +import raylras.zen.code.symbol.Locatable; import raylras.zen.code.symbol.Symbol; import raylras.zen.langserver.Document; import raylras.zen.util.CSTNodes; @@ -27,10 +27,14 @@ public static CompletableFuture, List CompletableFuture.supplyAsync(() -> { Position cursor = Position.of(params.getPosition()); ParseTree cst = CSTNodes.getCstAtPosition(unit.getParseTree(), cursor); + if(cst == null) { + logger.warn("Could not get symbol at ({}, {}), skipping goto definition", params.getPosition().getLine(), params.getPosition().getCharacter()); + return null; + } org.eclipse.lsp4j.Range originSelectionRange = Range.of(cst).toLspRange(); Collection symbols = SymbolResolver.lookupSymbol(cst, unit); return Either., List>forRight(symbols.stream() - .filter(symbol -> symbol instanceof ParseTreeLocatable) + .filter(symbol -> symbol instanceof Locatable) .map(symbol -> toLocationLink(symbol, originSelectionRange)) .toList()); })).orElseGet(DefinitionProvider::empty); @@ -41,7 +45,7 @@ public static CompletableFuture, List> references(Document do List result = getSearchingScope(symbol, unit).stream().parallel().flatMap(cu -> { String uri = cu.getPath().toUri().toString(); return searchPossible(searchRule, cu.getParseTree()) - .stream().parallel().filter(cst -> { - Collection symbols = SymbolResolver.lookupSymbol(cst, unit); - return symbols.stream().anyMatch(it -> Objects.equals(it, symbol)); - }) + .stream().parallel().filter(cst -> isReferenceToSymbol(cu, cst, symbol)) .map(it -> toLocation(uri, it)); } ).toList(); @@ -104,6 +101,16 @@ public void visitTerminal(TerminalNode node) { return result; } + private static boolean isReferenceToSymbol(CompilationUnit unit, ParseTree cst, Symbol targetSymbol) { + if (targetSymbol.getKind() == Symbol.Kind.IMPORT) { + Collection symbols = SymbolResolver.lookupSymbol(cst, unit, false); + return symbols.stream().anyMatch(it -> Objects.equals(it, targetSymbol)); + } else { + Collection symbols = SymbolResolver.lookupSymbol(cst, unit); + return symbols.stream().anyMatch(it -> Objects.equals(it, targetSymbol)); + } + } + private static Predicate getSymbolSearchRule(Symbol symbol) { if (symbol instanceof OperatorFunctionSymbol operator) { String opName = operator.getName(); @@ -148,7 +155,7 @@ private static boolean isGloballyAccessibleSymbol(Symbol symbol) { return true; } - if (symbol.getKind() == Symbol.Kind.VARIABLE && symbol instanceof ParseTreeLocatable locatable) { + if (symbol instanceof ParseTreeLocatable locatable) { ParseTree parent = CSTNodes.findParentOfTypes(locatable.getCst(), ZenScriptParser.ClassDeclarationContext.class, ZenScriptParser.BlockStatementContext.class); // variables and functions in classes are accessible by other units. if (parent instanceof ZenScriptParser.ClassDeclarationContext) { diff --git a/src/main/java/raylras/zen/util/Symbols.java b/src/main/java/raylras/zen/util/Symbols.java index 4a88ec57..301e68ec 100644 --- a/src/main/java/raylras/zen/util/Symbols.java +++ b/src/main/java/raylras/zen/util/Symbols.java @@ -1,12 +1,12 @@ package raylras.zen.util; import raylras.zen.code.CompilationEnvironment; +import raylras.zen.code.CompilationUnit; import raylras.zen.code.SymbolProvider; -import raylras.zen.code.symbol.Symbol; -import raylras.zen.code.symbol.Executable; +import raylras.zen.code.symbol.*; import raylras.zen.code.type.Type; -import java.util.Collections; -import java.util.List; + +import java.util.*; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -40,4 +40,33 @@ public static List getMember(Type type, Class clazz, Co } + + public static List lookupGlobalSymbols(CompilationUnit unit, String name) { + List globals = unit.getEnv().getGlobalSymbols() + .stream() + .filter(it -> !(it instanceof ParseTreeLocatable locatable && locatable.getUnit() == unit)) + .filter(it -> Objects.equals(it.getName(), name)) + .toList(); + if (globals.isEmpty()) { + return getTopLevelPackageSymbols(unit) + .stream() + .filter(it -> Objects.equals(it.getName(), name)) + .toList(); + } + return globals; + } + + private static List getTopLevelPackageSymbols(CompilationUnit unit) { + PackageTree packageTree = PackageTree.of(".", unit.getEnv().getClassSymbolMap()); + List globalPackages = new ArrayList<>(packageTree.getSubTrees().size()); + CompilationEnvironment environment = unit.getEnv(); + for (Map.Entry> entry : packageTree.getSubTrees().entrySet()) { + boolean isGenerated = !"scripts".equals(entry.getKey()); + PackageSymbol packageSymbol = SymbolFactory.createPackageSymbol(entry.getKey(), entry.getKey(), entry.getValue(), environment, isGenerated); + globalPackages.add(packageSymbol); + } + return globalPackages; + } + + }