diff --git a/common/src/main/java/org/opensearch/sql/common/utils/DebugUtils.java b/common/src/main/java/org/opensearch/sql/common/utils/DebugUtils.java
new file mode 100644
index 00000000000..9082839360d
--- /dev/null
+++ b/common/src/main/java/org/opensearch/sql/common/utils/DebugUtils.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.common.utils;
+
+import java.util.Collection;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+/**
+ * Utility class for debugging operations. This class is only for debugging purposes and is not
+ * intended to be used in production code.
+ */
+public class DebugUtils {
+  // Update this to true while you are debugging. (Safeguard to avoid usage in production code.)
+  private static final boolean IS_DEBUG = false;
+  private static final Logger logger = LogManager.getLogger(DebugUtils.class);
+
+  public static <T> T debug(T obj, String message) {
+    verifyDebug();
+    print("### %s: %s (at %s)", message, stringify(obj), getCalledFrom(1));
+    return obj;
+  }
+
+  public static <T> T debug(T obj) {
+    verifyDebug();
+    print("### %s (at %s)", stringify(obj), getCalledFrom(1));
+    return obj;
+  }
+
+  private static void verifyDebug() {
+    if (!IS_DEBUG) {
+      throw new RuntimeException("DebugUtils can be used only for local debugging.");
+    }
+  }
+
+  private static void print(String format, Object... args) {
+    logger.info(String.format(format, args));
+  }
+
+  private static String getCalledFrom(int pos) {
+    RuntimeException e = new RuntimeException();
+    StackTraceElement item = e.getStackTrace()[pos + 1];
+    return item.getClassName() + "." + item.getMethodName() + ":" + item.getLineNumber();
+  }
+
+  private static String stringify(Collection<?> items) {
+    if (items == null) {
+      return "null";
+    }
+
+    if (items.isEmpty()) {
+      return "()";
+    }
+
+    String result = items.stream().map(i -> stringify(i)).collect(Collectors.joining(","));
+
+    return "(" + result + ")";
+  }
+
+  private static String stringify(Map<?, ?> map) {
+    if (map == null) {
+      return "[[null]]";
+    }
+
+    if (map.isEmpty()) {
+      return "[[EMPTY]]";
+    }
+
+    String result =
+        map.entrySet().stream()
+            .map(entry -> entry.getKey() + ": " + stringify(entry.getValue()))
+            .collect(Collectors.joining(","));
+    return "{" + result + "}";
+  }
+
+  private static String stringify(Object obj) {
+    if (obj instanceof Collection) {
+      return stringify((Collection<?>) obj);
+    } else if (obj instanceof Map) {
+      return stringify((Map<?, ?>) obj);
+    }
+    return String.valueOf(obj);
+  }
+}
diff --git a/common/src/test/java/org/opensearch/sql/common/utils/DebugUtilsTest.java b/common/src/test/java/org/opensearch/sql/common/utils/DebugUtilsTest.java
new file mode 100644
index 00000000000..b464492b607
--- /dev/null
+++ b/common/src/test/java/org/opensearch/sql/common/utils/DebugUtilsTest.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.common.utils;
+
+import static org.junit.Assert.assertThrows;
+
+import org.junit.Test;
+
+public class DebugUtilsTest {
+
+  @Test
+  public void testDebugThrowsRuntimeException() {
+    assertThrows(RuntimeException.class, () -> DebugUtils.debug("test", "test message"));
+  }
+
+  @Test
+  public void testDebugWithoutMessageThrowsRuntimeException() {
+    assertThrows(RuntimeException.class, () -> DebugUtils.debug("test"));
+  }
+}
diff --git 
a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 24cef144c97..65056aecbad 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -104,6 +104,7 @@ import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.common.patterns.PatternUtils; import org.opensearch.sql.data.model.ExprMissingValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.datasource.DataSourceService; @@ -953,10 +954,10 @@ private Aggregation analyzePatternsAgg(Patterns node) { List aggExprs = Stream.of( new Alias( - "pattern_count", + PatternUtils.PATTERN_COUNT, new AggregateFunction(BuiltinFunctionName.COUNT.name(), AllFields.of())), new Alias( - "sample_logs", + PatternUtils.SAMPLE_LOGS, new AggregateFunction( BuiltinFunctionName.TAKE.name(), node.getSourceField(), diff --git a/core/src/main/java/org/opensearch/sql/ast/AstNodeUtils.java b/core/src/main/java/org/opensearch/sql/ast/AstNodeUtils.java new file mode 100644 index 00000000000..8de6ea57a1f --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/AstNodeUtils.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast; + +import lombok.experimental.UtilityClass; +import org.opensearch.sql.ast.expression.Let; +import org.opensearch.sql.ast.expression.subquery.SubqueryExpression; + +/** Utility class for AST node operations shared among visitor classes. */ +@UtilityClass +public class AstNodeUtils { + + /** + * Checks if an AST node contains a subquery expression. + * + * @param expr The AST node to check + * @return true if the node or any of its children contains a subquery expression + */ + public static boolean containsSubqueryExpression(Node expr) { + if (expr == null) { + return false; + } + if (expr instanceof SubqueryExpression) { + return true; + } + if (expr instanceof Let l) { + return containsSubqueryExpression(l.getExpression()); + } + for (Node child : expr.getChild()) { + if (containsSubqueryExpression(child)) { + return true; + } + } + return false; + } +} diff --git a/core/src/main/java/org/opensearch/sql/ast/analysis/FieldResolutionContext.java b/core/src/main/java/org/opensearch/sql/ast/analysis/FieldResolutionContext.java new file mode 100644 index 00000000000..8a463f3854b --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/analysis/FieldResolutionContext.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.analysis; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.IdentityHashMap; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.Getter; +import org.opensearch.sql.ast.tree.Relation; +import org.opensearch.sql.ast.tree.UnresolvedPlan; + +/** Context for field resolution using stack-based traversal. 
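+ *
+ * <p>Illustration (hypothetical caller, values chosen for demonstration only): a visitor pushes
+ * the requirements of the command it is processing before visiting the child node and pops them
+ * afterwards, so {@code getCurrentRequirements()} always reflects the innermost enclosing
+ * command.
+ *
+ * <pre>{@code
+ * FieldResolutionContext context = new FieldResolutionContext();
+ * context.pushRequirements(new FieldResolutionResult(Set.of("status", "clientip")));
+ * FieldResolutionResult current = context.getCurrentRequirements();
+ * // ... visit child nodes with the narrowed requirements ...
+ * context.popRequirements();
+ * }</pre>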
 */
+public class FieldResolutionContext {
+
+  @Getter private final Map<UnresolvedPlan, FieldResolutionResult> results;
+  private final Deque<FieldResolutionResult> requirementsStack;
+
+  public FieldResolutionContext() {
+    this.results = new IdentityHashMap<>();
+    this.requirementsStack = new ArrayDeque<>();
+    this.requirementsStack.push(new FieldResolutionResult(Set.of(), "*"));
+  }
+
+  public void pushRequirements(FieldResolutionResult result) {
+    requirementsStack.push(result);
+  }
+
+  public FieldResolutionResult popRequirements() {
+    return requirementsStack.pop();
+  }
+
+  public FieldResolutionResult getCurrentRequirements() {
+    if (requirementsStack.isEmpty()) {
+      throw new RuntimeException("empty stack");
+    } else {
+      return requirementsStack.peek();
+    }
+  }
+
+  public void setResult(UnresolvedPlan relation, FieldResolutionResult result) {
+    results.put(relation, result);
+  }
+
+  public Set<Relation> getRelations() {
+    return results.keySet().stream()
+        .filter(k -> k instanceof Relation)
+        .map(k -> (Relation) k)
+        .collect(Collectors.toSet());
+  }
+
+  public static String mergeWildcardPatterns(Set<String> patterns) {
+    if (patterns == null || patterns.isEmpty()) {
+      return null;
+    }
+    if (patterns.size() == 1) {
+      return patterns.iterator().next();
+    }
+    return String.join(" | ", patterns.stream().sorted().toList());
+  }
+
+  @Override
+  public String toString() {
+    return "FieldResolutionContext{relationResults=" + results + "}";
+  }
+}
diff --git a/core/src/main/java/org/opensearch/sql/ast/analysis/FieldResolutionResult.java b/core/src/main/java/org/opensearch/sql/ast/analysis/FieldResolutionResult.java
new file mode 100644
index 00000000000..3c433dbb4a0
--- /dev/null
+++ b/core/src/main/java/org/opensearch/sql/ast/analysis/FieldResolutionResult.java
@@ -0,0 +1,271 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.ast.analysis;
+
+import com.google.common.collect.ImmutableList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.NonNull;
+import lombok.ToString;
+import lombok.Value;
+import org.opensearch.sql.calcite.utils.WildcardUtils;
+
+/** Field resolution result separating regular fields from wildcard patterns.
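+ *
+ * <p>Illustration (hypothetical field names, for demonstration only): {@code or} adds required
+ * fields, {@code exclude} removes fields that a command produces itself, and the wildcard tracks
+ * patterns such as {@code "client*"} that cannot be resolved to concrete field names yet.
+ *
+ * <pre>{@code
+ * FieldResolutionResult projected = new FieldResolutionResult(Set.of("status"), "client*");
+ * FieldResolutionResult withSort = projected.or(Set.of("bytes"));      // adds a required field
+ * FieldResolutionResult narrowed = withSort.exclude(Set.of("bytes"));  // drops it again
+ * boolean matched = projected.getWildcard().matches("clientip");       // wildcard match
+ * }</pre>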
*/ +@Getter +@EqualsAndHashCode +@ToString +public class FieldResolutionResult { + + @NonNull private final Set regularFields; + @NonNull private final Wildcard wildcard; + + public FieldResolutionResult(Set regularFields) { + this.regularFields = new HashSet<>(regularFields); + this.wildcard = NULL_WILDCARD; + } + + public FieldResolutionResult(Set regularFields, Wildcard wildcard) { + this.regularFields = new HashSet<>(regularFields); + this.wildcard = wildcard; + } + + public FieldResolutionResult(Set regularFields, String wildcard) { + this.regularFields = new HashSet<>(regularFields); + this.wildcard = getWildcard(wildcard); + } + + private static Wildcard getWildcard(String wildcard) { + if (wildcard == null || wildcard.isEmpty()) { + return NULL_WILDCARD; + } else if (wildcard.equals("*")) { + return ANY_WILDCARD; + } else { + return new SingleWildcard(wildcard); + } + } + + public FieldResolutionResult(Set regularFields, Set wildcards) { + this.regularFields = new HashSet<>(regularFields); + this.wildcard = createOrWildcard(wildcards); + } + + private static Wildcard createOrWildcard(Set patterns) { + if (patterns == null || patterns.isEmpty()) { + return NULL_WILDCARD; + } + if (patterns.size() == 1) { + return getWildcard(patterns.iterator().next()); + } + List wildcards = + patterns.stream().sorted().map(SingleWildcard::new).collect(Collectors.toList()); + return new OrWildcard(wildcards); + } + + public Set getRegularFieldsUnmodifiable() { + return Collections.unmodifiableSet(regularFields); + } + + public boolean hasWildcards() { + return wildcard != NULL_WILDCARD; + } + + public boolean hasRegularFields() { + return !regularFields.isEmpty(); + } + + public FieldResolutionResult exclude(Collection fields) { + Set combinedFields = new HashSet<>(this.regularFields); + combinedFields.removeAll(fields); + return new FieldResolutionResult(combinedFields, this.wildcard); + } + + public FieldResolutionResult or(Set fields) { + Set combinedFields = new HashSet<>(this.regularFields); + combinedFields.addAll(fields); + return new FieldResolutionResult(combinedFields, this.wildcard); + } + + private Set and(Set fields) { + return fields.stream() + .filter(field -> this.getRegularFields().contains(field) || this.wildcard.matches(field)) + .collect(Collectors.toSet()); + } + + public FieldResolutionResult and(FieldResolutionResult other) { + Set combinedFields = this.and(other.regularFields); + combinedFields.addAll(other.and(this.regularFields)); + + Wildcard combinedWildcard = this.wildcard.and(other.wildcard); + + return new FieldResolutionResult(combinedFields, combinedWildcard); + } + + /** Interface for wildcard pattern matching. 
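+   *
+   * <p>Composition sketch (illustrative patterns only): {@code and}/{@code or} build composite
+   * wildcards; {@code ANY_WILDCARD} matches everything and {@code NULL_WILDCARD} matches
+   * nothing.
+   *
+   * <pre>{@code
+   * Wildcard a = new SingleWildcard("client*");
+   * Wildcard b = new SingleWildcard("*ip");
+   * Wildcard both = a.and(b);   // matches names satisfying both patterns, e.g. "clientip"
+   * Wildcard either = a.or(b);  // matches names satisfying either pattern
+   * }</pre>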
*/ + public interface Wildcard { + boolean matches(String fieldName); + + default Wildcard and(Wildcard other) { + return new AndWildcard(this, other); + } + + default Wildcard or(Wildcard other) { + return new OrWildcard(this, other); + } + } + + static Wildcard ANY_WILDCARD = + new Wildcard() { + @Override + public boolean matches(String fieldName) { + return true; + } + + @Override + public String toString() { + return "*"; + } + + @Override + public Wildcard and(Wildcard other) { + return other; + } + + @Override + public Wildcard or(Wildcard other) { + return this; + } + }; + + static Wildcard NULL_WILDCARD = + new Wildcard() { + public boolean matches(String fieldName) { + return false; + } + + @Override + public String toString() { + return ""; + } + + @Override + public Wildcard and(Wildcard other) { + return this; + } + + @Override + public Wildcard or(Wildcard other) { + return other; + } + }; + + /** Single wildcard pattern using '*' as wildcard character. */ + @Value + static class SingleWildcard implements Wildcard { + String pattern; + + @Override + public boolean matches(String fieldName) { + return WildcardUtils.matchesWildcardPattern(pattern, fieldName); + } + + @Override + public String toString() { + return pattern; + } + } + + /** OR combination of wildcard patterns (matches if ANY pattern matches). */ + @Value + static class OrWildcard implements Wildcard { + List patterns; + + public OrWildcard(Wildcard... patterns) { + this.patterns = List.of(patterns); + } + + public OrWildcard(Collection patterns) { + this.patterns = List.copyOf(patterns); + } + + @Override + public boolean matches(String fieldName) { + return patterns.stream().anyMatch(p -> p.matches(fieldName)); + } + + @Override + public String toString() { + return patterns.stream().map(Wildcard::toString).collect(Collectors.joining(" | ")); + } + + @Override + public Wildcard or(Wildcard other) { + if (other instanceof SingleWildcard) { + List newPatterns = + ImmutableList.builder() + .addAll(patterns) + .add(other) + .build(); + return new OrWildcard(newPatterns); + } else if (other == NULL_WILDCARD) { + return this; + } else if (other == ANY_WILDCARD) { + return ANY_WILDCARD; + } else { + return Wildcard.super.or(other); + } + } + } + + /** AND combination of wildcard patterns (matches if ALL patterns match). */ + @Value + static class AndWildcard implements Wildcard { + List patterns; + + public AndWildcard(Wildcard... 
patterns) { + this.patterns = List.of(patterns); + } + + public AndWildcard(Collection patterns) { + this.patterns = List.copyOf(patterns); + } + + @Override + public boolean matches(String fieldName) { + return patterns.stream().allMatch(p -> p.matches(fieldName)); + } + + @Override + public String toString() { + return patterns.stream() + .map(p -> "(" + p.toString() + ")") + .collect(Collectors.joining(" & ")); + } + + @Override + public Wildcard and(Wildcard other) { + if (other instanceof SingleWildcard) { + List newPatterns = + ImmutableList.builder() + .addAll(patterns) + .add(other) + .build(); + return new AndWildcard(newPatterns); + } else if (other == NULL_WILDCARD) { + return NULL_WILDCARD; + } else if (other == ANY_WILDCARD) { + return this; + } else { + return Wildcard.super.and(other); + } + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/ast/analysis/FieldResolutionVisitor.java b/core/src/main/java/org/opensearch/sql/ast/analysis/FieldResolutionVisitor.java new file mode 100644 index 00000000000..eff567ea498 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/analysis/FieldResolutionVisitor.java @@ -0,0 +1,624 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.analysis; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Predicate; +import java.util.function.UnaryOperator; +import java.util.stream.Collectors; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.AstNodeUtils; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.Let; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.PatternMode; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.Span; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.AddColTotals; +import org.opensearch.sql.ast.tree.AddTotals; +import org.opensearch.sql.ast.tree.Aggregation; +import org.opensearch.sql.ast.tree.Append; +import org.opensearch.sql.ast.tree.AppendCol; +import org.opensearch.sql.ast.tree.AppendPipe; +import org.opensearch.sql.ast.tree.Bin; +import org.opensearch.sql.ast.tree.Chart; +import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.Expand; +import org.opensearch.sql.ast.tree.FillNull; +import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Flatten; +import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.ast.tree.Lookup; +import org.opensearch.sql.ast.tree.Multisearch; +import org.opensearch.sql.ast.tree.Parse; +import org.opensearch.sql.ast.tree.Patterns; +import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.RareTopN; +import org.opensearch.sql.ast.tree.Regex; +import org.opensearch.sql.ast.tree.Relation; +import org.opensearch.sql.ast.tree.Rename; +import org.opensearch.sql.ast.tree.Replace; +import org.opensearch.sql.ast.tree.Reverse; +import org.opensearch.sql.ast.tree.Rex; +import org.opensearch.sql.ast.tree.SPath; +import 
org.opensearch.sql.ast.tree.Search; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.StreamWindow; +import org.opensearch.sql.ast.tree.SubqueryAlias; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.ast.tree.Values; +import org.opensearch.sql.ast.tree.Window; +import org.opensearch.sql.calcite.utils.WildcardUtils; +import org.opensearch.sql.common.patterns.PatternUtils; +import org.opensearch.sql.expression.parse.PatternsExpression; +import org.opensearch.sql.expression.parse.RegexCommonUtils; + +/** + * Visitor to analyze and collect required fields from PPL AST using stack-based traversal. The + * result is used to support `spath` command (Field Resolution-based Extraction) and schema-on-read + * support in the future. + */ +public class FieldResolutionVisitor extends AbstractNodeVisitor { + + /** + * Analyzes PPL query plan to determine required fields at each node. + * + * @param plan root node of the PPL query plan + * @return map of plan nodes to their field requirements (regular fields and wildcard patterns) + * @throws IllegalArgumentException if plan contains unsupported commands or spath with wildcards + */ + public Map analyze(UnresolvedPlan plan) { + FieldResolutionContext context = new FieldResolutionContext(); + acceptAndVerifyNodeVisited(plan, context); + return context.getResults(); + } + + @Override + public Node visitChildren(Node node, FieldResolutionContext context) { + for (Node child : node.getChild()) { + acceptAndVerifyNodeVisited(child, context); + } + return null; + } + + /** + * Visit node and verify it returns same node. This ensures all the visit methods are implemented + * in this class. + */ + private void acceptAndVerifyNodeVisited(Node node, FieldResolutionContext context) { + Node result = node.accept(this, context); + if (result != node) { + throw new IllegalArgumentException( + "Unsupported command for field resolution: " + node.getClass().getSimpleName()); + } + } + + @Override + public Node visitProject(Project node, FieldResolutionContext context) { + boolean isSelectAll = + node.getProjectList().stream().anyMatch(expr -> expr instanceof AllFields); + + if (isSelectAll) { + visitChildren(node, context); + } else { + Set projectFields = new HashSet<>(); + Set wildcardPatterns = new HashSet<>(); + for (UnresolvedExpression expr : node.getProjectList()) { + extractFieldsFromExpression(expr) + .forEach( + field -> { + if (WildcardUtils.containsWildcard(field)) { + wildcardPatterns.add(field); + } else { + projectFields.add(field); + } + }); + } + + FieldResolutionResult current = context.getCurrentRequirements(); + context.pushRequirements( + current.and(new FieldResolutionResult(projectFields, wildcardPatterns))); + visitChildren(node, context); + context.popRequirements(); + } + return node; + } + + @Override + public Node visitFilter(Filter node, FieldResolutionContext context) { + Set filterFields = extractFieldsFromExpression(node.getCondition()); + if (AstNodeUtils.containsSubqueryExpression(node.getCondition())) { + // Does not support subquery as we cannot distinguish correl variable without static schema + throw new IllegalArgumentException( + "Filter by subquery is not supported with field resolution."); + } + + context.pushRequirements(context.getCurrentRequirements().or(filterFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitAggregation(Aggregation node, 
FieldResolutionContext context) { + Set aggFields = new HashSet<>(); + for (UnresolvedExpression groupExpr : node.getGroupExprList()) { + aggFields.addAll(extractFieldsFromExpression(groupExpr)); + } + if (node.getSpan() != null) { + aggFields.addAll(extractFieldsFromExpression(node.getSpan())); + } + for (UnresolvedExpression aggExpr : node.getAggExprList()) { + aggFields.addAll(extractFieldsFromAggregation(aggExpr)); + } + + context.pushRequirements(new FieldResolutionResult(aggFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitSpath(SPath node, FieldResolutionContext context) { + if (node.getPath() != null) { + return visitEval(node.rewriteAsEval(), context); + } else { + // set requirements for spath command; + context.setResult(node, context.getCurrentRequirements()); + FieldResolutionResult requirements = context.getCurrentRequirements(); + if (requirements.hasWildcards()) { + throw new IllegalArgumentException( + "Spath command cannot extract arbitrary fields. Please project fields explicitly by" + + " fields command without wildcard or stats command."); + } + + context.pushRequirements(context.getCurrentRequirements().or(Set.of(node.getInField()))); + visitChildren(node, context); + context.popRequirements(); + return node; + } + } + + @Override + public Node visitSort(Sort node, FieldResolutionContext context) { + Set sortFields = new HashSet<>(); + for (Field sortField : node.getSortList()) { + sortFields.addAll(extractFieldsFromExpression(sortField)); + } + + context.pushRequirements(context.getCurrentRequirements().or(sortFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitEval(Eval node, FieldResolutionContext context) { + Set evalInputFields = new HashSet<>(); + Set computedFields = new HashSet<>(); + + for (Let letExpr : node.getExpressionList()) { + evalInputFields.addAll(extractFieldsFromExpression(letExpr.getExpression())); + computedFields.add(letExpr.getVar().getField().toString()); + } + + FieldResolutionResult currentReq = context.getCurrentRequirements(); + Set allRequiredFields = new HashSet<>(currentReq.getRegularFields()); + allRequiredFields.removeAll(computedFields); + allRequiredFields.addAll(evalInputFields); + + context.pushRequirements( + new FieldResolutionResult(allRequiredFields, currentReq.getWildcard())); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + private Set extractFieldsFromExpression(UnresolvedExpression expr) { + Set fields = new HashSet<>(); + if (expr == null) { + return fields; + } + + if (expr instanceof Field field) { + fields.add(field.getField().toString()); + } else if (expr instanceof QualifiedName name) { + fields.add(name.toString()); + } else if (expr instanceof Alias alias) { + fields.addAll(extractFieldsFromExpression(alias.getDelegated())); + } else if (expr instanceof Function function) { + for (UnresolvedExpression arg : function.getFuncArgs()) { + fields.addAll(extractFieldsFromExpression(arg)); + } + } else if (expr instanceof Span span) { + fields.addAll(extractFieldsFromExpression(span.getField())); + } else if (expr instanceof Literal) { + return fields; + } else { + for (Node child : expr.getChild()) { + if (child instanceof UnresolvedExpression childExpr) { + fields.addAll(extractFieldsFromExpression(childExpr)); + } + } + } + return fields; + } + + @Override + public Node visitJoin(Join node, FieldResolutionContext context) { + Set joinFields = 
new HashSet<>(); + + if (node.getJoinCondition().isPresent()) { + joinFields.addAll(extractFieldsFromExpression(node.getJoinCondition().get())); + } + + if (node.getJoinFields().isPresent()) { + for (Field field : node.getJoinFields().get()) { + joinFields.addAll(extractFieldsFromExpression(field)); + } + } + + FieldResolutionResult currentReq = context.getCurrentRequirements(); + Set baseRequiredFields = new HashSet<>(currentReq.getRegularFields()); + + String leftAlias = node.getLeftAlias().orElse(null); + String rightAlias = node.getRightAlias().orElse(null); + + Set leftFields = collectFieldsByAlias(baseRequiredFields, leftAlias, rightAlias); + leftFields.addAll(collectFieldsByAlias(joinFields, leftAlias, rightAlias)); + + Set rightFields = collectFieldsByAlias(baseRequiredFields, rightAlias, leftAlias); + rightFields.addAll(collectFieldsByAlias(joinFields, rightAlias, leftAlias)); + + if (node.getLeft() != null) { + context.pushRequirements(new FieldResolutionResult(leftFields, currentReq.getWildcard())); + node.getLeft().accept(this, context); + context.popRequirements(); + } + + if (node.getRight() != null) { + context.pushRequirements(new FieldResolutionResult(rightFields, currentReq.getWildcard())); + node.getRight().accept(this, context); + context.popRequirements(); + } + + return node; + } + + /** + * Return lambda which remove alias from the input field, do nothing if the input does not start + * from the alias. + */ + private static UnaryOperator removeAlias(String alias) { + return (field) -> hasAlias(field, alias) ? field.substring(alias.length() + 1) : field; + } + + /** Return predicate to exclude the field which has the alias. */ + private static Predicate excludeAlias(String alias) { + return (field) -> !hasAlias(field, alias); + } + + private Set collectFieldsByAlias(Set fields, String alias, String excludedAlias) { + return fields.stream() + .filter(excludeAlias(excludedAlias)) + .map(removeAlias(alias)) + .collect(Collectors.toSet()); + } + + private static boolean hasAlias(String field, String alias) { + return alias != null && field.startsWith(alias + "."); + } + + @Override + public Node visitSubqueryAlias(SubqueryAlias node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitRelation(Relation node, FieldResolutionContext context) { + FieldResolutionResult currentReq = context.getCurrentRequirements(); + + context.setResult( + node, new FieldResolutionResult(currentReq.getRegularFields(), currentReq.getWildcard())); + return node; + } + + // Commands that don't modify field requirements - just pass through to children + @Override + public Node visitSearch(Search node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitAppendPipe(AppendPipe node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitRegex(Regex node, FieldResolutionContext context) { + Set regexFields = extractFieldsFromExpression(node.getField()); + context.pushRequirements(context.getCurrentRequirements().or(regexFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitRex(Rex node, FieldResolutionContext context) { + Set rexFields = extractFieldsFromExpression(node.getField()); + String patternStr = (String) node.getPattern().getValue(); + List namedGroups = RegexCommonUtils.getNamedGroupCandidates(patternStr); + + 
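+    // The named capture groups are columns produced by rex itself, so they are excluded from the
+    // requirements pushed to the child, while the source field of rex is added as required.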
context.pushRequirements(context.getCurrentRequirements().exclude(namedGroups).or(rexFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitBin(Bin node, FieldResolutionContext context) { + Set binFields = extractFieldsFromExpression(node.getField()); + context.pushRequirements(context.getCurrentRequirements().or(binFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitParse(Parse node, FieldResolutionContext context) { + Set parseFields = extractFieldsFromExpression(node.getSourceField()); + context.pushRequirements(context.getCurrentRequirements().or(parseFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitPatterns(Patterns node, FieldResolutionContext context) { + Set patternFields = extractFieldsFromExpression(node.getSourceField()); + for (UnresolvedExpression partitionBy : node.getPartitionByList()) { + patternFields.addAll(extractFieldsFromExpression(partitionBy)); + } + Set addedFields = new HashSet<>(); + addedFields.add( + node.getAlias() != null ? node.getAlias() : PatternsExpression.DEFAULT_NEW_FIELD); + if (node.getPatternMode() == PatternMode.AGGREGATION) { + addedFields.add(PatternUtils.PATTERN_COUNT); + addedFields.add(PatternUtils.SAMPLE_LOGS); + } + if (node.getShowNumberedToken() != null) { + boolean showNumberedToken = Boolean.parseBoolean(node.getShowNumberedToken().toString()); + if (showNumberedToken) { + addedFields.add(PatternUtils.TOKENS); + } + } + + context.pushRequirements( + context.getCurrentRequirements().exclude(addedFields).or(patternFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitReverse(Reverse node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitHead(Head node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitRename(Rename node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitDedupe(Dedupe node, FieldResolutionContext context) { + Set dedupeFields = new HashSet<>(); + for (Field field : node.getFields()) { + dedupeFields.addAll(extractFieldsFromExpression(field)); + } + context.pushRequirements(context.getCurrentRequirements().or(dedupeFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitWindow(Window node, FieldResolutionContext context) { + Set windowFields = new HashSet<>(); + for (UnresolvedExpression windowFunc : node.getWindowFunctionList()) { + windowFields.addAll(extractFieldsFromExpression(windowFunc)); + } + if (node.getGroupList() != null) { + for (UnresolvedExpression groupExpr : node.getGroupList()) { + windowFields.addAll(extractFieldsFromExpression(groupExpr)); + } + } + context.pushRequirements(context.getCurrentRequirements().or(windowFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitStreamWindow(StreamWindow node, FieldResolutionContext context) { + Set streamWindowFields = new HashSet<>(); + for (UnresolvedExpression windowFunc : node.getWindowFunctionList()) { + streamWindowFields.addAll(extractFieldsFromExpression(windowFunc)); + } + if (node.getGroupList() != null) { + for (UnresolvedExpression groupExpr : 
node.getGroupList()) { + streamWindowFields.addAll(extractFieldsFromExpression(groupExpr)); + } + } + if (node.getResetBefore() != null) { + streamWindowFields.addAll(extractFieldsFromExpression(node.getResetBefore())); + } + if (node.getResetAfter() != null) { + streamWindowFields.addAll(extractFieldsFromExpression(node.getResetAfter())); + } + context.pushRequirements(context.getCurrentRequirements().or(streamWindowFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitFillNull(FillNull node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitAppendCol(AppendCol node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitAppend(Append node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitMultisearch(Multisearch node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitLookup(Lookup node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitValues(Values node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitReplace(Replace node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitFlatten(Flatten node, FieldResolutionContext context) { + Set flattenFields = extractFieldsFromExpression(node.getField()); + context.pushRequirements(context.getCurrentRequirements().or(flattenFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitTrendline(Trendline node, FieldResolutionContext context) { + Set trendlineFields = new HashSet<>(); + for (Trendline.TrendlineComputation computation : node.getComputations()) { + trendlineFields.addAll(extractFieldsFromExpression(computation.getDataField())); + } + if (node.getSortByField().isPresent()) { + trendlineFields.addAll(extractFieldsFromExpression(node.getSortByField().get())); + } + context.pushRequirements(context.getCurrentRequirements().or(trendlineFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitChart(Chart node, FieldResolutionContext context) { + Set chartFields = extractFieldsFromAggregation(node.getAggregationFunction()); + if (node.getRowSplit() != null) { + chartFields.addAll(extractFieldsFromExpression(node.getRowSplit())); + } + if (node.getColumnSplit() != null) { + chartFields.addAll(extractFieldsFromExpression(node.getColumnSplit())); + } + context.pushRequirements(new FieldResolutionResult(chartFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitRareTopN(RareTopN node, FieldResolutionContext context) { + Set rareTopNFields = new HashSet<>(); + for (Field field : node.getFields()) { + rareTopNFields.addAll(extractFieldsFromExpression(field)); + } + for (UnresolvedExpression groupExpr : node.getGroupExprList()) { + rareTopNFields.addAll(extractFieldsFromExpression(groupExpr)); + } + context.pushRequirements(new FieldResolutionResult(rareTopNFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + @Override + public Node visitAddTotals(AddTotals node, 
FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitAddColTotals(AddColTotals node, FieldResolutionContext context) { + visitChildren(node, context); + return node; + } + + @Override + public Node visitExpand(Expand node, FieldResolutionContext context) { + Set expandFields = extractFieldsFromExpression(node.getField()); + context.pushRequirements(context.getCurrentRequirements().or(expandFields)); + visitChildren(node, context); + context.popRequirements(); + return node; + } + + private Set extractFieldsFromAggregation(UnresolvedExpression expr) { + Set fields = new HashSet<>(); + if (expr instanceof Alias alias) { + return extractFieldsFromAggregation(alias.getDelegated()); + } else if (expr instanceof AggregateFunction aggFunc) { + if (aggFunc.getField() != null) { + fields.addAll(extractFieldsFromExpression(aggFunc.getField())); + } + if (aggFunc.getArgList() != null) { + for (UnresolvedExpression arg : aggFunc.getArgList()) { + fields.addAll(extractFieldsFromExpression(arg)); + } + } + } + return fields; + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java b/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java index 6ad935e59da..3cea56fe7a2 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java @@ -21,7 +21,10 @@ import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.tools.FrameworkConfig; +import org.opensearch.sql.ast.analysis.FieldResolutionResult; +import org.opensearch.sql.ast.analysis.FieldResolutionVisitor; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.calcite.utils.CalciteToolsHelper; import org.opensearch.sql.calcite.utils.CalciteToolsHelper.OpenSearchRelBuilder; import org.opensearch.sql.common.setting.Settings; @@ -72,6 +75,9 @@ public class CalcitePlanContext { /** Whether we're currently inside a lambda context. */ @Getter @Setter private boolean inLambdaContext = false; + /** Root node of the AST tree. Used for field resolution */ + @Setter private UnresolvedPlan rootNode; + private CalcitePlanContext(FrameworkConfig config, SysLimit sysLimit, QueryType queryType) { this.config = config; this.sysLimit = sysLimit; @@ -206,4 +212,24 @@ public RexLambdaRef captureVariable(RexNode fieldRef, String fieldName) { return lambdaRef; } + + /** + * Resolves required fields for a target node in the PPL query plan by analyzing the AST from + * root. Used for schema-on-read features like `spath` command. + * + * @param target the plan node to resolve field requirements for + * @return field resolution result with regular fields and wildcard patterns + * @throws IllegalStateException if root node not set via {@link #setRootNode} + */ + public FieldResolutionResult resolveFields(UnresolvedPlan target) { + if (rootNode == null) { + throw new IllegalStateException("Failed to resolve fields. 
Root node is not set."); + } + FieldResolutionVisitor visitor = new FieldResolutionVisitor(); + Map result = visitor.analyze(rootNode); + if (!result.containsKey(target)) { + throw new IllegalStateException("Failed to resolve fields for node: " + target.toString()); + } + return result.get(target); + } } diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index 937a35b98cb..f1bc5fd6a0d 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -75,8 +75,9 @@ import org.checkerframework.checker.nullness.qual.Nullable; import org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.AstNodeUtils; import org.opensearch.sql.ast.EmptySourcePropagateVisitor; -import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.analysis.FieldResolutionResult; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; @@ -86,7 +87,6 @@ import org.opensearch.sql.ast.expression.Argument.ArgumentMap; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; -import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.ParseMethod; import org.opensearch.sql.ast.expression.PatternMethod; @@ -97,7 +97,6 @@ import org.opensearch.sql.ast.expression.WindowFrame; import org.opensearch.sql.ast.expression.WindowFrame.FrameType; import org.opensearch.sql.ast.expression.WindowFunction; -import org.opensearch.sql.ast.expression.subquery.SubqueryExpression; import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.AddColTotals; import org.opensearch.sql.ast.tree.AddTotals; @@ -236,7 +235,8 @@ public RelNode visitSearch(Search node, CalcitePlanContext context) { @Override public RelNode visitFilter(Filter node, CalcitePlanContext context) { visitChildren(node, context); - boolean containsSubqueryExpression = containsSubqueryExpression(node.getCondition()); + boolean containsSubqueryExpression = + AstNodeUtils.containsSubqueryExpression(node.getCondition()); final Holder<@Nullable RexCorrelVariable> v = Holder.empty(); if (containsSubqueryExpression) { context.relBuilder.variable(v::set); @@ -365,24 +365,6 @@ public RelNode visitRex(Rex node, CalcitePlanContext context) { return context.relBuilder.peek(); } - private boolean containsSubqueryExpression(Node expr) { - if (expr == null) { - return false; - } - if (expr instanceof SubqueryExpression) { - return true; - } - if (expr instanceof Let l) { - return containsSubqueryExpression(l.getExpression()); - } - for (Node child : expr.getChild()) { - if (containsSubqueryExpression(child)) { - return true; - } - } - return false; - } - @Override public RelNode visitProject(Project node, CalcitePlanContext context) { visitChildren(node, context); @@ -721,7 +703,63 @@ public RelNode visitParse(Parse node, CalcitePlanContext context) { @Override public RelNode visitSpath(SPath node, CalcitePlanContext context) { - return visitEval(node.rewriteAsEval(), context); + if (node.getPath() != null) { + return visitEval(node.rewriteAsEval(), context); + } else { + return spathExtractAll(node, context); + } + } + + private RelNode spathExtractAll(SPath node, 
CalcitePlanContext context) { + visitChildren(node, context); + + FieldResolutionResult resolutionResult = context.resolveFields(node); + if (resolutionResult.hasWildcards()) { + // Logic for handling wildcards (dynamic fields) will be implemented later. + throw new IllegalArgumentException( + "spath command failed to identify fields to extract. Use fields/stats command to specify" + + " output fields."); + } + + // 1. Extract all fields from JSON in `inField` + RexNode inField = rexVisitor.analyze(AstDSL.field(node.getInField()), context); + RexNode map = makeCall(context, BuiltinFunctionName.JSON_EXTRACT_ALL, inField); + + // 2. Project items from FieldResolutionResult + Set existingFields = + new HashSet<>(context.relBuilder.peek().getRowType().getFieldNames()); + List fieldNames = + resolutionResult.getRegularFields().stream().sorted().collect(Collectors.toList()); + List fields = new ArrayList<>(); + for (String fieldName : fieldNames) { + RexNode item = itemCall(map, fieldName, context); + // Cast to string for type consistency. (This cast will be removed once functions are adopted + // to ANY type) + item = context.relBuilder.cast(item, SqlTypeName.VARCHAR); + // Append if field already exist + if (existingFields.contains(fieldName)) { + item = + makeCall( + context, + BuiltinFunctionName.INTERNAL_APPEND, + context.relBuilder.field(fieldName), + item); + } + fields.add(context.relBuilder.alias(item, fieldName)); + } + + context.relBuilder.project(fields); + return context.relBuilder.peek(); + } + + private RexNode itemCall(RexNode node, String key, CalcitePlanContext context) { + return makeCall( + context, BuiltinFunctionName.INTERNAL_ITEM, node, context.rexBuilder.makeLiteral(key)); + } + + private RexNode makeCall( + CalcitePlanContext context, BuiltinFunctionName functionName, RexNode... 
args) { + return PPLFuncImpTable.INSTANCE.resolve(context.rexBuilder, functionName, args); } @Override @@ -864,7 +902,7 @@ public RelNode visitEval(Eval node, CalcitePlanContext context) { node.getExpressionList() .forEach( expr -> { - boolean containsSubqueryExpression = containsSubqueryExpression(expr); + boolean containsSubqueryExpression = AstNodeUtils.containsSubqueryExpression(expr); final Holder<@Nullable RexCorrelVariable> v = Holder.empty(); if (containsSubqueryExpression) { context.relBuilder.variable(v::set); diff --git a/core/src/main/java/org/opensearch/sql/executor/QueryService.java b/core/src/main/java/org/opensearch/sql/executor/QueryService.java index 8edc3ad3f2c..b2c219b0e4e 100644 --- a/core/src/main/java/org/opensearch/sql/executor/QueryService.java +++ b/core/src/main/java/org/opensearch/sql/executor/QueryService.java @@ -252,6 +252,7 @@ public void executePlan( } public RelNode analyze(UnresolvedPlan plan, CalcitePlanContext context) { + context.setRootNode(plan); return getRelNodeVisitor().analyze(plan, context); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index dce558bf7cc..50f88d47baf 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -75,6 +75,7 @@ public enum BuiltinFunctionName { MAP_CONCAT(FunctionName.of("map_concat"), true), MAP_REMOVE(FunctionName.of("map_remove"), true), MVAPPEND(FunctionName.of("mvappend")), + INTERNAL_APPEND(FunctionName.of("append"), true), MVJOIN(FunctionName.of("mvjoin")), MVINDEX(FunctionName.of("mvindex")), MVFIND(FunctionName.of("mvfind")), diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/MVAppendCore.java b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/AppendCore.java similarity index 50% rename from core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/MVAppendCore.java rename to core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/AppendCore.java index f9a67e4d6d8..1c3ef4efbe2 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/MVAppendCore.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/AppendCore.java @@ -8,14 +8,32 @@ import java.util.ArrayList; import java.util.List; -/** Core logic for `mvappend` command to collect elements from list of args */ -public class MVAppendCore { +/** + * Core logic for `mvappend` and internal `append` function to collect elements from list of args. + */ +public class AppendCore { /** * Collect non-null elements from `args`. If an item is a list, it will collect non-null elements - * of the list. See {@ref MVAppendFunctionImplTest} for detailed behavior. + * of the list. See {@link AppendFunctionImplTest} for detailed behavior. + */ + public static Object collectElements(Object... args) { + List elements = collectElementsToList(args); + + if (elements.isEmpty()) { + return null; + } else if (elements.size() == 1) { + // return the element in case of single element + return elements.get(0); + } else { + return elements; + } + } + + /** + * Collect non-null elements from `args`. If an item is a list, it will collect non-null elements. */ - public static List collectElements(Object... args) { + public static List collectElementsToList(Object... 
args) { List elements = new ArrayList<>(); for (Object arg : args) { @@ -28,7 +46,7 @@ public static List collectElements(Object... args) { } } - return elements.isEmpty() ? null : elements; + return elements; } private static void addListElements(List list, List elements) { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/AppendFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/AppendFunctionImpl.java new file mode 100644 index 00000000000..cecc19a14bb --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/AppendFunctionImpl.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function.CollectionUDF; + +import java.util.List; +import org.apache.calcite.adapter.enumerable.NotNullImplementor; +import org.apache.calcite.adapter.enumerable.NullPolicy; +import org.apache.calcite.adapter.enumerable.RexToLixTranslator; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.linq4j.tree.Types; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.opensearch.sql.expression.function.ImplementorUDF; +import org.opensearch.sql.expression.function.UDFOperandMetadata; + +/** + * Internal append function that appends all elements from arguments to create an array. Returns + * null if there is no element. Returns the scalar value if there is single element. Otherwise, + * returns a list containing all the elements from inputs. + */ +public class AppendFunctionImpl extends ImplementorUDF { + + public AppendFunctionImpl() { + super(new AppendImplementor(), NullPolicy.ALL); + } + + @Override + public SqlReturnTypeInference getReturnTypeInference() { + return sqlOperatorBinding -> { + RelDataTypeFactory typeFactory = sqlOperatorBinding.getTypeFactory(); + + if (sqlOperatorBinding.getOperandCount() == 0) { + return typeFactory.createSqlType(SqlTypeName.NULL); + } + + // Return type is ANY as it could return scalar value (in case of single item) or array + return typeFactory.createSqlType(SqlTypeName.ANY); + }; + } + + @Override + public UDFOperandMetadata getOperandMetadata() { + return null; + } + + public static class AppendImplementor implements NotNullImplementor { + @Override + public Expression implement( + RexToLixTranslator translator, RexCall call, List translatedOperands) { + return Expressions.call( + Types.lookupMethod(AppendFunctionImpl.class, "append", Object[].class), + Expressions.newArrayInit(Object.class, translatedOperands)); + } + } + + public static Object append(Object... args) { + return AppendCore.collectElements(args); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/MVAppendFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/MVAppendFunctionImpl.java index 107df5eea4e..968176d3479 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/MVAppendFunctionImpl.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/MVAppendFunctionImpl.java @@ -97,6 +97,15 @@ public Expression implement( } public static Object mvappend(Object... 
args) { - return MVAppendCore.collectElements(args); + return collectElements(args); + } + + /** + * Collect non-null elements from `args`. If an item is a list, it will collect non-null elements + * of the list. See {@link MVAppendFunctionImplTest} for detailed behavior. + */ + public static List collectElements(Object... args) { + List elements = AppendCore.collectElementsToList(args); + return elements.isEmpty() ? null : elements; } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/MapAppendFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/MapAppendFunctionImpl.java index 4cb0acae612..c45e95288bb 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/MapAppendFunctionImpl.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/MapAppendFunctionImpl.java @@ -24,8 +24,8 @@ import org.opensearch.sql.expression.function.UDFOperandMetadata; /** - * MapAppend function that merges two maps. All the values will be converted to list for type - * consistency. + * MapAppend function that merges two maps. Value for the same key will be merged into an array by + * using {@link AppendCore}. */ public class MapAppendFunctionImpl extends ImplementorUDF { @@ -89,7 +89,7 @@ private static Map verifyMap(Object map) { static Map mapAppendImpl(Map map) { Map result = new HashMap<>(); for (String key : map.keySet()) { - result.put(key, MVAppendCore.collectElements(map.get(key))); + result.put(key, AppendCore.collectElements(map.get(key))); } return result; } @@ -99,7 +99,7 @@ static Map mapAppendImpl( Map result = new HashMap<>(); for (String key : mergeKeys(firstMap, secondMap)) { - result.put(key, MVAppendCore.collectElements(firstMap.get(key), secondMap.get(key))); + result.put(key, AppendCore.collectElements(firstMap.get(key), secondMap.get(key))); } return result; diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java index eaede1da7e4..16a27a57aac 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java @@ -8,7 +8,7 @@ import org.apache.commons.lang3.tuple.Pair; /** - * An interface for any class that can provide a {@ref FunctionBuilder} given a {@ref + * An interface for any class that can provide a {@link FunctionBuilder} given a {@link * FunctionSignature}. 
*/ public interface FunctionResolver { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java index 2d769194924..3810352cbfd 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java @@ -42,6 +42,7 @@ import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.datetime.DateTimeFunctions; +import org.opensearch.sql.expression.function.CollectionUDF.AppendFunctionImpl; import org.opensearch.sql.expression.function.CollectionUDF.ArrayFunctionImpl; import org.opensearch.sql.expression.function.CollectionUDF.ExistsFunctionImpl; import org.opensearch.sql.expression.function.CollectionUDF.FilterFunctionImpl; @@ -392,8 +393,9 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { public static final SqlOperator EXISTS = new ExistsFunctionImpl().toUDF("exists"); public static final SqlOperator ARRAY = new ArrayFunctionImpl().toUDF("array"); public static final SqlOperator MAP_APPEND = new MapAppendFunctionImpl().toUDF("map_append"); - public static final SqlOperator MAP_REMOVE = new MapRemoveFunctionImpl().toUDF("MAP_REMOVE"); + public static final SqlOperator MAP_REMOVE = new MapRemoveFunctionImpl().toUDF("map_remove"); public static final SqlOperator MVAPPEND = new MVAppendFunctionImpl().toUDF("mvappend"); + public static final SqlOperator INTERNAL_APPEND = new AppendFunctionImpl().toUDF("append"); public static final SqlOperator MVZIP = new MVZipFunctionImpl().toUDF("mvzip"); public static final SqlOperator MVFIND = new MVFindFunctionImpl().toUDF("mvfind"); public static final SqlOperator FILTER = new FilterFunctionImpl().toUDF("filter"); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 205f3a0f2e1..2d594c48f55 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -80,6 +80,7 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.IF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IFNULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ILIKE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_APPEND; import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_GROK; import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_ITEM; import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_PARSE; @@ -1039,6 +1040,7 @@ void populate() { registerOperator(ARRAY, PPLBuiltinOperators.ARRAY); registerOperator(MVAPPEND, PPLBuiltinOperators.MVAPPEND); + registerOperator(INTERNAL_APPEND, PPLBuiltinOperators.INTERNAL_APPEND); registerOperator(MVDEDUP, SqlLibraryOperators.ARRAY_DISTINCT); registerOperator(MVFIND, PPLBuiltinOperators.MVFIND); registerOperator(MVZIP, PPLBuiltinOperators.MVZIP); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/jsonUDF/JsonExtractAllFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/jsonUDF/JsonExtractAllFunctionImpl.java index 1f91c87bb77..d79ba14c3bf 
100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/jsonUDF/JsonExtractAllFunctionImpl.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/jsonUDF/JsonExtractAllFunctionImpl.java @@ -17,6 +17,7 @@ import java.util.List; import java.util.Map; import java.util.Stack; +import java.util.stream.Collectors; import org.apache.calcite.adapter.enumerable.NotNullImplementor; import org.apache.calcite.adapter.enumerable.NullPolicy; import org.apache.calcite.adapter.enumerable.RexImpTable; @@ -25,6 +26,7 @@ import org.apache.calcite.linq4j.tree.Types; import org.apache.calcite.rex.RexCall; import org.apache.calcite.schema.impl.ScalarFunctionImpl; +import org.apache.calcite.sql.type.CompositeOperandTypeChecker; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlReturnTypeInference; @@ -36,7 +38,7 @@ /** * UDF which extract all the fields from JSON to a MAP. Items are collected from input JSON and * stored with the key of their path in the JSON. This UDF is designed to be used for `spath` - * command without path param. See {@ref JsonExtractAllFunctionImplTest} for the detailed spec. + * command without path param. See {@link JsonExtractAllFunctionImplTest} for the detailed spec. */ public class JsonExtractAllFunctionImpl extends ImplementorUDF { private static final String ARRAY_SUFFIX = "{}"; @@ -57,7 +59,9 @@ public SqlReturnTypeInference getReturnTypeInference() { @Override public UDFOperandMetadata getOperandMetadata() { - return UDFOperandMetadata.wrap(OperandTypes.family(SqlTypeFamily.STRING)); + return UDFOperandMetadata.wrap( + (CompositeOperandTypeChecker) + OperandTypes.family(SqlTypeFamily.STRING).or(OperandTypes.family(SqlTypeFamily.ARRAY))); } public static class JsonExtractAllImplementor implements NotNullImplementor { @@ -77,12 +81,32 @@ public static Object eval(Object... args) { return null; } - String jsonStr = (String) args[0]; - if (jsonStr == null || jsonStr.trim().isEmpty()) { + String jsonStr = getString(args[0]); + return jsonStr != null ? convertEmptyMapToNull(parseJson(jsonStr)) : null; + } + + private static Map convertEmptyMapToNull(Map map) { + return (map == null || map.isEmpty()) ? 
null : map; + } + + private static String getString(Object input) { + if (input instanceof String) { + return (String) input; + } else if (input instanceof List) { + return convertArrayToString((List) input); + } + return null; + } + + private static String convertArrayToString(List array) { + if (array == null || array.isEmpty()) { return null; } - return parseJson(jsonStr); + return array.stream() + .filter(element -> element != null) + .map(Object::toString) + .collect(Collectors.joining()); } private static Map parseJson(String jsonStr) { diff --git a/core/src/test/java/org/opensearch/sql/ast/AstNodeUtilsTest.java b/core/src/test/java/org/opensearch/sql/ast/AstNodeUtilsTest.java new file mode 100644 index 00000000000..28089109e9e --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/ast/AstNodeUtilsTest.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.doReturn; + +import java.util.Arrays; +import java.util.Collections; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Let; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.expression.subquery.ExistsSubquery; +import org.opensearch.sql.ast.expression.subquery.InSubquery; +import org.opensearch.sql.ast.expression.subquery.ScalarSubquery; +import org.opensearch.sql.ast.tree.UnresolvedPlan; + +@ExtendWith(MockitoExtension.class) +public class AstNodeUtilsTest { + + @Mock private UnresolvedPlan mockPlan; + @Mock private UnresolvedExpression mockExpr; + @Mock private Node mockNode; + @Mock private Node mockParent; + @Mock private Node mockChild; + @Mock private Node mockGrandchild; + + private ScalarSubquery scalarSubquery; + private ExistsSubquery existsSubquery; + private InSubquery inSubquery; + + @BeforeEach + public void setUp() { + scalarSubquery = new ScalarSubquery(mockPlan); + existsSubquery = new ExistsSubquery(mockPlan); + inSubquery = new InSubquery(Collections.singletonList(mockExpr), mockPlan); + } + + @Test + public void testContainsSubqueryExpressionWithNull() { + assertFalse(AstNodeUtils.containsSubqueryExpression(null)); + } + + @Test + public void testContainsSubqueryExpressionWithScalarSubquery() { + assertTrue(AstNodeUtils.containsSubqueryExpression(scalarSubquery)); + } + + @Test + public void testContainsSubqueryExpressionWithExistsSubquery() { + assertTrue(AstNodeUtils.containsSubqueryExpression(existsSubquery)); + } + + @Test + public void testContainsSubqueryExpressionWithInSubquery() { + assertTrue(AstNodeUtils.containsSubqueryExpression(inSubquery)); + } + + @Test + public void testContainsSubqueryExpressionWithLetContainingSubquery() { + Field field = new Field(QualifiedName.of("test")); + Let letExpr = new Let(field, scalarSubquery); + assertTrue(AstNodeUtils.containsSubqueryExpression(letExpr)); + } + + @Test + public void testContainsSubqueryExpressionWithLetNotContainingSubquery() { + Literal literal = new Literal(42, null); + Field field = new 
Field(QualifiedName.of("test")); + Let letExpr = new Let(field, literal); + assertFalse(AstNodeUtils.containsSubqueryExpression(letExpr)); + } + + @Test + public void testContainsSubqueryExpressionWithSimpleExpression() { + Literal literal = new Literal(42, null); + assertFalse(AstNodeUtils.containsSubqueryExpression(literal)); + } + + @Test + public void testContainsSubqueryExpressionWithNodeHavingSubqueryChild() { + doReturn(Collections.singletonList(scalarSubquery)).when(mockParent).getChild(); + assertTrue(AstNodeUtils.containsSubqueryExpression(mockParent)); + } + + @Test + public void testContainsSubqueryExpressionWithNodeHavingNoSubqueryChild() { + Literal literal = new Literal(42, null); + doReturn(Collections.singletonList(literal)).when(mockParent).getChild(); + assertFalse(AstNodeUtils.containsSubqueryExpression(mockParent)); + } + + @Test + public void testContainsSubqueryExpressionWithNodeHavingMultipleChildren() { + Literal literal1 = new Literal(1, null); + Literal literal2 = new Literal(2, null); + doReturn(Arrays.asList(literal1, scalarSubquery, literal2)).when(mockParent).getChild(); + assertTrue(AstNodeUtils.containsSubqueryExpression(mockParent)); + } + + @Test + public void testContainsSubqueryExpressionWithNodeHavingNoChildren() { + doReturn(Collections.emptyList()).when(mockNode).getChild(); + assertFalse(AstNodeUtils.containsSubqueryExpression(mockNode)); + } + + @Test + public void testContainsSubqueryExpressionWithNestedStructure() { + doReturn(Collections.singletonList(scalarSubquery)).when(mockChild).getChild(); + doReturn(Collections.singletonList(mockChild)).when(mockParent).getChild(); + assertTrue(AstNodeUtils.containsSubqueryExpression(mockParent)); + } + + @Test + public void testContainsSubqueryExpressionWithDeeplyNestedStructure() { + doReturn(Collections.singletonList(scalarSubquery)).when(mockGrandchild).getChild(); + doReturn(Collections.singletonList(mockGrandchild)).when(mockChild).getChild(); + doReturn(Collections.singletonList(mockChild)).when(mockParent).getChild(); + assertTrue(AstNodeUtils.containsSubqueryExpression(mockParent)); + } + + @Test + public void testContainsSubqueryExpressionWithLetNestedInNode() { + Field field = new Field(QualifiedName.of("test")); + Let letExpr = new Let(field, scalarSubquery); + doReturn(Collections.singletonList(letExpr)).when(mockParent).getChild(); + assertTrue(AstNodeUtils.containsSubqueryExpression(mockParent)); + } + + @Test + public void testContainsSubqueryExpressionWithMultipleLetExpressions() { + Field field1 = new Field(QualifiedName.of("test1")); + Let letWithSubquery = new Let(field1, scalarSubquery); + + Literal literal = new Literal(42, null); + Field field2 = new Field(QualifiedName.of("test2")); + Let letWithoutSubquery = new Let(field2, literal); + + doReturn(Arrays.asList(letWithoutSubquery, letWithSubquery)).when(mockParent).getChild(); + assertTrue(AstNodeUtils.containsSubqueryExpression(mockParent)); + } + + @Test + public void testContainsSubqueryExpressionWithComplexNestedLet() { + Field innerField = new Field(QualifiedName.of("inner")); + Let innerLet = new Let(innerField, scalarSubquery); + + Field outerField = new Field(QualifiedName.of("outer")); + Let outerLet = new Let(outerField, innerLet); + + assertTrue(AstNodeUtils.containsSubqueryExpression(outerLet)); + } +} diff --git a/core/src/test/java/org/opensearch/sql/ast/analysis/FieldResolutionResultTest.java b/core/src/test/java/org/opensearch/sql/ast/analysis/FieldResolutionResultTest.java new file mode 100644 index 
00000000000..31c009f76c1 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/ast/analysis/FieldResolutionResultTest.java @@ -0,0 +1,517 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.analysis; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Set; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.analysis.FieldResolutionResult.AndWildcard; +import org.opensearch.sql.ast.analysis.FieldResolutionResult.OrWildcard; +import org.opensearch.sql.ast.analysis.FieldResolutionResult.SingleWildcard; +import org.opensearch.sql.ast.analysis.FieldResolutionResult.Wildcard; + +class FieldResolutionResultTest { + + @Test + void testSingleWildcardMatching() { + Wildcard wildcard = new SingleWildcard("user*"); + + assertTrue(wildcard.matches("user")); + assertTrue(wildcard.matches("username")); + assertTrue(wildcard.matches("user_id")); + assertFalse(wildcard.matches("admin")); + assertFalse(wildcard.matches("name")); + } + + @Test + void testSingleWildcardWithMultipleWildcards() { + Wildcard wildcard = new SingleWildcard("*_id_*"); + + assertTrue(wildcard.matches("user_id_123")); + assertTrue(wildcard.matches("_id_")); + assertTrue(wildcard.matches("prefix_id_suffix")); + assertFalse(wildcard.matches("user_id")); + assertFalse(wildcard.matches("id_123")); + } + + @Test + void testSingleWildcardExactMatch() { + Wildcard wildcard = new SingleWildcard("exact_field"); + + assertTrue(wildcard.matches("exact_field")); + assertFalse(wildcard.matches("exact_field_suffix")); + assertFalse(wildcard.matches("prefix_exact_field")); + } + + @Test + void testSingleWildcardToString() { + Wildcard wildcard = new SingleWildcard("user*"); + assertEquals("user*", wildcard.toString()); + } + + @Test + void testSingleWildcardEquality() { + Wildcard w1 = new SingleWildcard("user*"); + Wildcard w2 = new SingleWildcard("user*"); + Wildcard w3 = new SingleWildcard("admin*"); + + assertEquals(w1, w2); + assertEquals(w1.hashCode(), w2.hashCode()); + assertNotEquals(w1, w3); + } + + @Test + void testOrWildcardMatching() { + Wildcard wildcard = new OrWildcard(new SingleWildcard("user*"), new SingleWildcard("admin*")); + + assertTrue(wildcard.matches("username")); + assertTrue(wildcard.matches("admin_role")); + assertFalse(wildcard.matches("guest")); + } + + @Test + void testOrWildcardWithSinglePattern() { + Wildcard wildcard = new OrWildcard(new SingleWildcard("user*")); + + assertTrue(wildcard.matches("username")); + assertFalse(wildcard.matches("admin")); + } + + @Test + void testOrWildcardToString() { + Wildcard wildcard = new OrWildcard(new SingleWildcard("user*"), new SingleWildcard("admin*")); + + assertEquals("user* | admin*", wildcard.toString()); + } + + @Test + void testOrWildcardEquality() { + Wildcard w1 = new OrWildcard(new SingleWildcard("user*"), new SingleWildcard("admin*")); + Wildcard w2 = new OrWildcard(new SingleWildcard("user*"), new SingleWildcard("admin*")); + Wildcard w3 = new OrWildcard(new SingleWildcard("guest*"), new SingleWildcard("admin*")); + + assertEquals(w1, w2); + assertEquals(w1.hashCode(), w2.hashCode()); + assertNotEquals(w1, w3); + } + + @Test + void testAndWildcardMatching() { + Wildcard wildcard = new AndWildcard(new SingleWildcard("*user*"), new SingleWildcard("*name")); + + 
assertTrue(wildcard.matches("username")); + assertTrue(wildcard.matches("user_full_name")); + assertFalse(wildcard.matches("user_id")); + assertFalse(wildcard.matches("admin_name")); + } + + @Test + void testAndWildcardWithSinglePattern() { + Wildcard wildcard = new AndWildcard(new SingleWildcard("user*")); + + assertTrue(wildcard.matches("username")); + assertFalse(wildcard.matches("admin")); + } + + @Test + void testAndWildcardToString() { + Wildcard wildcard = new AndWildcard(new SingleWildcard("*user*"), new SingleWildcard("*name")); + + assertEquals("(*user*) & (*name)", wildcard.toString()); + } + + @Test + void testAndWildcardEquality() { + Wildcard w1 = new AndWildcard(new SingleWildcard("*user*"), new SingleWildcard("*name")); + Wildcard w2 = new AndWildcard(new SingleWildcard("*user*"), new SingleWildcard("*name")); + Wildcard w3 = new AndWildcard(new SingleWildcard("*admin*"), new SingleWildcard("*name")); + + assertEquals(w1, w2); + assertEquals(w1.hashCode(), w2.hashCode()); + assertNotEquals(w1, w3); + } + + @Test + void testNestedWildcardCombinations() { + Wildcard or1 = new OrWildcard(new SingleWildcard("user*"), new SingleWildcard("admin*")); + Wildcard or2 = new OrWildcard(new SingleWildcard("*_id"), new SingleWildcard("*_name")); + Wildcard and = new AndWildcard(or1, or2); + + assertTrue(and.matches("user_id")); + assertTrue(and.matches("admin_name")); + assertFalse(and.matches("user_role")); + assertFalse(and.matches("guest_id")); + } + + @Test + void testNestedWildcardToString() { + Wildcard or1 = new OrWildcard(new SingleWildcard("user*"), new SingleWildcard("admin*")); + Wildcard or2 = new OrWildcard(new SingleWildcard("*_id"), new SingleWildcard("*_name")); + Wildcard and = new AndWildcard(or1, or2); + + assertEquals("(user* | admin*) & (*_id | *_name)", and.toString()); + } + + @Test + void testFieldResolutionResultWithNoWildcard() { + FieldResolutionResult result = new FieldResolutionResult(Set.of("field1", "field2")); + + assertFalse(result.hasWildcards()); + assertTrue(result.hasRegularFields()); + assertEquals(2, result.getRegularFields().size()); + assertTrue(result.getWildcard().toString().isEmpty()); + } + + @Test + void testFieldResolutionResultWithEmptyFields() { + FieldResolutionResult result = new FieldResolutionResult(Set.of()); + + assertFalse(result.hasWildcards()); + assertFalse(result.hasRegularFields()); + assertEquals(0, result.getRegularFields().size()); + } + + @Test + void testFieldResolutionResultWithNullOrEmptyWildcardString() { + FieldResolutionResult result1 = new FieldResolutionResult(Set.of("field1"), (String) null); + FieldResolutionResult result2 = new FieldResolutionResult(Set.of("field1"), ""); + + assertTrue(result1.hasRegularFields()); + assertFalse(result1.hasWildcards()); + assertTrue(result1.getWildcard().toString().isEmpty()); + + assertTrue(result2.hasRegularFields()); + assertFalse(result2.hasWildcards()); + assertTrue(result2.getWildcard().toString().isEmpty()); + } + + @Test + void testFieldResolutionResultWithAnyWildcard() { + FieldResolutionResult result = new FieldResolutionResult(Set.of("field1"), "*"); + + assertTrue(result.hasWildcards()); + assertTrue(result.getWildcard().matches("anything")); + assertTrue(result.getWildcard().matches("field1")); + assertTrue(result.getWildcard().matches("")); + assertEquals("*", result.getWildcard().toString()); + } + + @Test + void testFieldResolutionResultWithNullOrEmptyWildcardSet() { + FieldResolutionResult result1 = new FieldResolutionResult(Set.of("field1"), Set.of()); + 
FieldResolutionResult result2 = new FieldResolutionResult(Set.of("field1"), (Set) null); + + assertTrue(result1.hasRegularFields()); + assertFalse(result1.hasWildcards()); + + assertTrue(result2.hasRegularFields()); + assertFalse(result2.hasWildcards()); + } + + @Test + void testGetRegularFieldsUnmodifiable() { + FieldResolutionResult result = new FieldResolutionResult(Set.of("field1", "field2")); + Set unmodifiable = result.getRegularFieldsUnmodifiable(); + + assertEquals(2, unmodifiable.size()); + assertTrue(unmodifiable.contains("field1")); + assertTrue(unmodifiable.contains("field2")); + + try { + unmodifiable.add("field3"); + assertTrue(false, "Should throw UnsupportedOperationException"); + } catch (UnsupportedOperationException e) { + // Expected + } + } + + @Test + void testFieldResolutionResultEquality() { + FieldResolutionResult result1 = new FieldResolutionResult(Set.of("field1", "field2"), "user*"); + FieldResolutionResult result2 = new FieldResolutionResult(Set.of("field1", "field2"), "user*"); + FieldResolutionResult result3 = new FieldResolutionResult(Set.of("field1"), "user*"); + FieldResolutionResult result4 = new FieldResolutionResult(Set.of("field1", "field2"), "admin*"); + + assertEquals(result1, result2); + assertEquals(result1.hashCode(), result2.hashCode()); + assertNotEquals(result1, result3); + assertNotEquals(result1, result4); + } + + @Test + void testFieldResolutionResultToString() { + FieldResolutionResult result = new FieldResolutionResult(Set.of("field1", "field2"), "user*"); + String str = result.toString(); + + assertTrue(str.contains("field1")); + assertTrue(str.contains("field2")); + assertTrue(str.contains("user*")); + } + + @Test + void testFieldResolutionResultWithStringWildcard() { + FieldResolutionResult result = new FieldResolutionResult(Set.of("field1"), "user*"); + + assertTrue(result.hasWildcards()); + assertTrue(result.hasRegularFields()); + assertTrue(result.getWildcard() instanceof SingleWildcard); + assertEquals("user*", result.getWildcard().toString()); + } + + @Test + void testFieldResolutionResultWithWildcardObject() { + Wildcard wildcard = new SingleWildcard("admin*"); + FieldResolutionResult result = new FieldResolutionResult(Set.of("field1"), wildcard); + + assertTrue(result.hasWildcards()); + assertEquals(wildcard, result.getWildcard()); + } + + @Test + void testFieldResolutionResultWithMultipleWildcardPatterns() { + FieldResolutionResult result = + new FieldResolutionResult(Set.of("field1"), Set.of("user*", "admin*")); + + assertTrue(result.hasWildcards()); + assertTrue(result.getWildcard() instanceof OrWildcard); + assertEquals("admin* | user*", result.getWildcard().toString()); + } + + @Test + void testFieldResolutionResultWithSingleWildcardPattern() { + FieldResolutionResult result = new FieldResolutionResult(Set.of("field1"), Set.of("user*")); + + assertTrue(result.hasWildcards()); + assertTrue(result.getWildcard() instanceof SingleWildcard); + assertEquals("user*", result.getWildcard().toString()); + } + + @Test + void testFieldResolutionResultAndOperation() { + FieldResolutionResult result1 = + new FieldResolutionResult(Set.of("user_id", "user_name", "admin_id"), "user*"); + FieldResolutionResult result2 = + new FieldResolutionResult(Set.of("user_id", "admin_id"), "*_id"); + + FieldResolutionResult combined = result1.and(result2); + + assertEquals(Set.of("user_id", "admin_id"), combined.getRegularFields()); + assertTrue(combined.hasWildcards()); + assertTrue(combined.getWildcard() instanceof AndWildcard); + assertEquals("(user*) 
& (*_id)", combined.getWildcard().toString()); + } + + @Test + void testFieldResolutionResultAndWithWildcardMatching() { + FieldResolutionResult result1 = + new FieldResolutionResult(Set.of("user_id", "user_name", "admin_role"), "user*"); + FieldResolutionResult result2 = new FieldResolutionResult(Set.of("admin_id"), "*_id"); + + FieldResolutionResult combined = result1.and(result2); + + assertTrue(combined.getRegularFields().contains("user_id")); + assertFalse(combined.getRegularFields().contains("user_name")); + assertFalse(combined.getRegularFields().contains("admin_role")); + } + + @Test + void testFieldResolutionResultAndWithNullWildcards() { + FieldResolutionResult result1 = new FieldResolutionResult(Set.of("user_id", "admin_id")); + FieldResolutionResult result2 = new FieldResolutionResult(Set.of("user_id"), "user*"); + + FieldResolutionResult combined = result1.and(result2); + + assertEquals(Set.of("user_id"), combined.getRegularFields()); + assertFalse(combined.hasWildcards()); + } + + @Test + void testFieldResolutionResultAndWithBothNullWildcards() { + FieldResolutionResult result1 = new FieldResolutionResult(Set.of("user_id", "admin_id")); + FieldResolutionResult result2 = new FieldResolutionResult(Set.of("user_id")); + + FieldResolutionResult combined = result1.and(result2); + + assertEquals(Set.of("user_id"), combined.getRegularFields()); + assertFalse(combined.hasWildcards()); + } + + @Test + void testFieldResolutionResultOrOperation() { + FieldResolutionResult result = new FieldResolutionResult(Set.of("field1"), "user*"); + FieldResolutionResult updated = result.or(Set.of("field2", "field3")); + + assertEquals(Set.of("field1", "field2", "field3"), updated.getRegularFields()); + assertEquals("user*", updated.getWildcard().toString()); + } + + @Test + void testFieldResolutionResultExcludeOperation() { + FieldResolutionResult result = + new FieldResolutionResult(Set.of("field1", "field2", "field3"), "user*"); + FieldResolutionResult updated = result.exclude(Set.of("field2")); + + assertEquals(Set.of("field1", "field3"), updated.getRegularFields()); + assertEquals("user*", updated.getWildcard().toString()); + } + + @Test + void testWildcardAndWithAnyWildcard() { + Wildcard single = new SingleWildcard("user*"); + + Wildcard result = FieldResolutionResult.ANY_WILDCARD.and(single); + assertEquals(single, result); + assertTrue(result.matches("username")); + assertFalse(result.matches("admin")); + } + + @Test + void testWildcardAndWithNullWildcard() { + Wildcard single = new SingleWildcard("user*"); + + Wildcard result = FieldResolutionResult.NULL_WILDCARD.and(single); + assertEquals(FieldResolutionResult.NULL_WILDCARD, result); + assertFalse(result.matches("username")); + } + + @Test + void testWildcardOrWithAnyWildcard() { + Wildcard single = new SingleWildcard("user*"); + + Wildcard result = FieldResolutionResult.ANY_WILDCARD.or(single); + assertEquals(FieldResolutionResult.ANY_WILDCARD, result); + assertTrue(result.matches("anything")); + } + + @Test + void testWildcardOrWithNullWildcard() { + Wildcard single = new SingleWildcard("user*"); + + Wildcard result = FieldResolutionResult.NULL_WILDCARD.or(single); + assertEquals(single, result); + assertTrue(result.matches("username")); + assertFalse(result.matches("admin")); + } + + @Test + void testOrWildcardOrWithSingleWildcard() { + OrWildcard or = new OrWildcard(new SingleWildcard("user*"), new SingleWildcard("admin*")); + SingleWildcard single = new SingleWildcard("guest*"); + + Wildcard result = or.or(single); + 
assertTrue(result instanceof OrWildcard); + assertTrue(result.matches("username")); + assertTrue(result.matches("admin_role")); + assertTrue(result.matches("guest_id")); + assertFalse(result.matches("other")); + } + + @Test + void testAndWildcardAndWithSingleWildcard() { + AndWildcard and = new AndWildcard(new SingleWildcard("*user*"), new SingleWildcard("*name")); + SingleWildcard single = new SingleWildcard("*full*"); + + Wildcard result = and.and(single); + assertTrue(result instanceof AndWildcard); + assertTrue(result.matches("user_full_name")); + assertFalse(result.matches("username")); + assertFalse(result.matches("user_id")); + } + + @Test + void testOrWildcardOrWithOrWildcard() { + OrWildcard or1 = new OrWildcard(new SingleWildcard("user*"), new SingleWildcard("admin*")); + OrWildcard or2 = new OrWildcard(new SingleWildcard("guest*"), new SingleWildcard("root*")); + + Wildcard result = or1.or(or2); + assertTrue(result instanceof OrWildcard); + assertTrue(result.matches("username")); + assertTrue(result.matches("admin_role")); + assertTrue(result.matches("guest_id")); + assertTrue(result.matches("root_access")); + assertFalse(result.matches("other")); + } + + @Test + void testOrWildcardOrWithAndWildcard() { + OrWildcard or = new OrWildcard(new SingleWildcard("user*"), new SingleWildcard("admin*")); + AndWildcard and = new AndWildcard(new SingleWildcard("*_id"), new SingleWildcard("guest*")); + + Wildcard result = or.or(and); + assertTrue(result instanceof OrWildcard); + assertTrue(result.matches("username")); + assertTrue(result.matches("admin_role")); + assertTrue(result.matches("guest_id")); + assertFalse(result.matches("guest_name")); + } + + @Test + void testAndWildcardAndWithAndWildcard() { + AndWildcard and1 = new AndWildcard(new SingleWildcard("*user*"), new SingleWildcard("*name")); + AndWildcard and2 = new AndWildcard(new SingleWildcard("*full*"), new SingleWildcard("user*")); + + Wildcard result = and1.and(and2); + assertTrue(result instanceof AndWildcard); + assertTrue(result.matches("user_full_name")); + assertFalse(result.matches("username")); + assertFalse(result.matches("user_name")); + assertFalse(result.matches("full_name")); + } + + @Test + void testAndWildcardAndWithOrWildcard() { + AndWildcard and = new AndWildcard(new SingleWildcard("*user*"), new SingleWildcard("*name")); + OrWildcard or = new OrWildcard(new SingleWildcard("*full*"), new SingleWildcard("*first*")); + + Wildcard result = and.and(or); + assertTrue(result instanceof AndWildcard); + assertTrue(result.matches("user_full_name")); + assertTrue(result.matches("user_first_name")); + assertFalse(result.matches("username")); + assertFalse(result.matches("user_last_name")); + } + + @Test + void testOrWildcardOrWithNullWildcard() { + OrWildcard or = new OrWildcard(new SingleWildcard("user*"), new SingleWildcard("admin*")); + + Wildcard result = or.or(FieldResolutionResult.NULL_WILDCARD); + assertEquals(or, result); + assertTrue(result.matches("username")); + assertTrue(result.matches("admin_role")); + } + + @Test + void testOrWildcardOrWithAnyWildcard() { + OrWildcard or = new OrWildcard(new SingleWildcard("user*"), new SingleWildcard("admin*")); + + Wildcard result = or.or(FieldResolutionResult.ANY_WILDCARD); + assertEquals(FieldResolutionResult.ANY_WILDCARD, result); + assertTrue(result.matches("anything")); + } + + @Test + void testAndWildcardAndWithNullWildcard() { + AndWildcard and = new AndWildcard(new SingleWildcard("*user*"), new SingleWildcard("*name")); + + Wildcard result = 
and.and(FieldResolutionResult.NULL_WILDCARD); + assertEquals(FieldResolutionResult.NULL_WILDCARD, result); + assertFalse(result.matches("username")); + } + + @Test + void testAndWildcardAndWithAnyWildcard() { + AndWildcard and = new AndWildcard(new SingleWildcard("*user*"), new SingleWildcard("*name")); + + Wildcard result = and.and(FieldResolutionResult.ANY_WILDCARD); + assertEquals(and, result); + assertTrue(result.matches("username")); + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/CollectionUDF/AppendFunctionImplTest.java b/core/src/test/java/org/opensearch/sql/expression/function/CollectionUDF/AppendFunctionImplTest.java new file mode 100644 index 00000000000..cbbfa67e8a3 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/CollectionUDF/AppendFunctionImplTest.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function.CollectionUDF; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +import java.util.Arrays; +import java.util.List; +import org.junit.jupiter.api.Test; + +public class AppendFunctionImplTest { + + @Test + public void testAppendWithNoArguments() { + Object result = AppendFunctionImpl.append(); + assertNull(result); + } + + @Test + public void testAppendWithSingleElement() { + Object result = AppendFunctionImpl.append(42); + assertEquals(42, result); + } + + @Test + public void testAppendWithMultipleElements() { + Object result = AppendFunctionImpl.append(1, 2, 3); + assertEquals(Arrays.asList(1, 2, 3), result); + } + + @Test + public void testAppendWithNullValues() { + Object result = AppendFunctionImpl.append(null, 1, null); + assertEquals(1, result); + } + + @Test + public void testAppendWithAllNulls() { + Object result = AppendFunctionImpl.append(null, null); + assertNull(result); + } + + @Test + public void testAppendWithArrayFlattening() { + List array1 = Arrays.asList(1, 2); + List array2 = Arrays.asList(3, 4); + Object result = AppendFunctionImpl.append(array1, array2); + assertEquals(Arrays.asList(1, 2, 3, 4), result); + } + + @Test + public void testAppendWithMixedTypes() { + List array = Arrays.asList(1, 2); + Object result = AppendFunctionImpl.append(array, 3, "hello"); + assertEquals(Arrays.asList(1, 2, 3, "hello"), result); + } + + @Test + public void testAppendWithArrayAndNulls() { + List array = Arrays.asList(1, 2); + Object result = AppendFunctionImpl.append(null, array, null, 3); + assertEquals(Arrays.asList(1, 2, 3), result); + } + + @Test + public void testAppendWithSingleNull() { + Object result = AppendFunctionImpl.append((Object) null); + assertNull(result); + } + + @Test + public void testAppendWithEmptyArray() { + List emptyArray = Arrays.asList(); + Object result = AppendFunctionImpl.append(emptyArray, 1); + assertEquals(1, result); + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/CollectionUDF/MapAppendFunctionImplTest.java b/core/src/test/java/org/opensearch/sql/expression/function/CollectionUDF/MapAppendFunctionImplTest.java index 28df3c3cbd1..6b131aed86b 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/CollectionUDF/MapAppendFunctionImplTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/CollectionUDF/MapAppendFunctionImplTest.java @@ -5,8 +5,8 @@ package org.opensearch.sql.expression.function.CollectionUDF; -import static 
org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; import java.util.HashMap; import java.util.List; @@ -21,11 +21,14 @@ void testMapAppendWithNonOverlappingKeys() { Map result = MapAppendFunctionImpl.mapAppendImpl(map1, map2); - assertEquals(4, result.size()); - assertMapListValues(result, "a", "value1"); - assertMapListValues(result, "b", "value2"); - assertMapListValues(result, "c", "value3"); - assertMapListValues(result, "d", "value4"); + assertThat( + result, + allOf( + hasEntry("a", "value1"), + hasEntry("b", "value2"), + hasEntry("c", "value3"), + hasEntry("d", "value4"), + aMapWithSize(4))); } @Test @@ -35,10 +38,13 @@ void testMapAppendWithOverlappingKeys() { Map result = MapAppendFunctionImpl.mapAppendImpl(map1, map2); - assertEquals(3, result.size()); - assertMapListValues(result, "a", "value1"); - assertMapListValues(result, "b", "value2", "value3"); - assertMapListValues(result, "c", "value4"); + assertThat( + result, + allOf( + aMapWithSize(3), + hasEntry("a", (Object) "value1"), + hasEntry("b", (Object) List.of("value2", "value3")), + hasEntry("c", (Object) "value4"))); } @Test @@ -48,10 +54,13 @@ void testMapAppendWithArrayValues() { Map result = MapAppendFunctionImpl.mapAppendImpl(map1, map2); - assertEquals(3, result.size()); - assertMapListValues(result, "a", "item1", "item2", "item3"); - assertMapListValues(result, "b", "single"); - assertMapListValues(result, "c", "item4", "item5"); + assertThat( + result, + allOf( + aMapWithSize(3), + hasEntry("a", (Object) List.of("item1", "item2", "item3")), + hasEntry("b", (Object) "single"), + hasEntry("c", (Object) List.of("item4", "item5")))); } @Test @@ -64,11 +73,14 @@ void testMapAppendWithNullValues() { Map result = MapAppendFunctionImpl.mapAppendImpl(map1, map2); - assertEquals(4, result.size()); - assertMapListValues(result, "a", "value1"); - assertMapListValues(result, "b", "value2"); - assertMapListValues(result, "c", "value3"); - assertMapListValues(result, "d", "value4"); + assertThat( + result, + allOf( + hasEntry("a", "value1"), + hasEntry("b", "value2"), + hasEntry("c", "value3"), + hasEntry("d", "value4"), + aMapWithSize(4))); } @Test @@ -77,9 +89,7 @@ void testMapAppendWithSingleParam() { Map result = MapAppendFunctionImpl.mapAppendImpl(map1); - assertEquals(2, result.size()); - assertMapListValues(result, "a", "value1"); - assertMapListValues(result, "b", "value2"); + assertThat(result, allOf(hasEntry("a", "value1"), hasEntry("b", "value2"), aMapWithSize(2))); } private Map getMap1() { @@ -95,14 +105,4 @@ private Map getMap2() { map2.put("d", "value4"); return map2; } - - private void assertMapListValues(Map map, String key, Object... 
expectedValues) { - Object val = map.get(key); - assertTrue(val instanceof List); - List result = (List) val; - assertEquals(expectedValues.length, result.size()); - for (int i = 0; i < expectedValues.length; i++) { - assertEquals(expectedValues[i], result.get(i)); - } - } } diff --git a/core/src/test/java/org/opensearch/sql/expression/function/jsonUDF/JsonExtractAllFunctionImplTest.java b/core/src/test/java/org/opensearch/sql/expression/function/jsonUDF/JsonExtractAllFunctionImplTest.java index 5a010a17422..af14453767e 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/jsonUDF/JsonExtractAllFunctionImplTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/jsonUDF/JsonExtractAllFunctionImplTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.util.Arrays; import java.util.List; import java.util.Map; import org.junit.jupiter.api.Test; @@ -71,39 +72,27 @@ public void testFunctionConstructor() { assertNotNull(testFunction, "Function should be properly initialized"); } - @Test - public void testNoArguments() { - Object result = JsonExtractAllFunctionImpl.eval(); - - assertNull(result); + private void assertEvalNull(Object... args) { + assertNull(JsonExtractAllFunctionImpl.eval(args)); } @Test - public void testNullInput() { - Object result = JsonExtractAllFunctionImpl.eval((String) null); + public void testNoArguments() { + Object result = JsonExtractAllFunctionImpl.eval(); assertNull(result); } @Test public void testEmptyString() { - Object result = JsonExtractAllFunctionImpl.eval(""); - - assertNull(result); - } - - @Test - public void testWhitespaceString() { - Object result = JsonExtractAllFunctionImpl.eval(" "); - - assertNull(result); - } - - @Test - public void testEmptyJsonObject() { - Map map = eval("{}"); - - assertTrue(map.isEmpty()); + assertEvalNull(); + assertEvalNull((String) null); + assertEvalNull(""); + assertEvalNull("", ""); + assertEvalNull(" "); + assertEvalNull("{}"); + assertEvalNull("\"just a string\""); + assertEvalNull("123"); } @Test @@ -149,20 +138,6 @@ public void testTopLevelArrayOfComplexObjects() { assertEquals(2, map.size()); } - @Test - public void testNonObjectJsonPrimitive() { - Object result = JsonExtractAllFunctionImpl.eval("\"just a string\""); - - assertNull(result); - } - - @Test - public void testNonObjectJsonNumber() { - Object result = JsonExtractAllFunctionImpl.eval("42"); - - assertNull(result); - } - @Test public void testSingleLevelNesting() { Map map = eval("{\"user\": {\"name\": \"John\"}, \"system\": \"linux\"}"); @@ -242,10 +217,7 @@ public void testNested() { @Test public void testEmptyArray() { - Map map = eval("{\"empty\": []}"); - - Object emptyValue = map.get("empty{}"); - assertNull(emptyValue); + assertEvalNull("{\"empty\": []}"); } @Test @@ -357,4 +329,44 @@ public void testLargeJsonObject() { assertEquals(0, map.get("field0")); assertEquals(99, map.get("field99")); } + + @Test + public void testArrayInputWithSingleJsonObject() { + List array = Arrays.asList("{\"name\": \"John\", \"age\": 30}"); + Object result = JsonExtractAllFunctionImpl.eval(array); + Map map = assertValidMapResult(result); + + assertEquals("John", map.get("name")); + assertEquals(30, map.get("age")); + assertEquals(2, map.size()); + } + + @Test + public void testArrayInputWithMultipleJsonFragments() { + List array = Arrays.asList("{\"name\": \"John\"", ", \"age\": 30}"); + Object result = 
JsonExtractAllFunctionImpl.eval(array); + Map map = assertValidMapResult(result); + + assertEquals("John", map.get("name")); + assertEquals(30, map.get("age")); + assertEquals(2, map.size()); + } + + @Test + public void testArrayInputWithNullElements() { + List array = Arrays.asList("{\"name\": ", null, "\"John\", \"age\": 30}"); + Object result = JsonExtractAllFunctionImpl.eval(array); + Map map = assertValidMapResult(result); + + assertEquals("John", map.get("name")); + assertEquals(30, map.get("age")); + assertEquals(2, map.size()); + } + + @Test + public void testNullAndEmptyArray() { + assertEvalNull(Arrays.asList(null, null, null)); + assertEvalNull(Arrays.asList()); + assertEvalNull((List) null); + } }
diff --git a/docs/dev/ppl-commands.md b/docs/dev/ppl-commands.md index ea727e234a5..5c3538883a9 100644 --- a/docs/dev/ppl-commands.md +++ b/docs/dev/ppl-commands.md @@ -27,6 +27,7 @@ If you are working on contributing a new PPL command, please read this guide and - [ ] **Visitor Pattern:** - Add `visit*` in `AbstractNodeVisitor` - Overriding `visit*` in `Analyzer`, `CalciteRelNodeVisitor` and `PPLQueryDataAnonymizer` + - Override `visit*` in `FieldResolutionVisitor` for `spath` command support. - [ ] **Unit Tests:** - Extend `CalcitePPLAbstractTest`
diff --git a/docs/user/ppl/cmd/spath.md b/docs/user/ppl/cmd/spath.md index d9293113fb0..1e63a8b01e7 100644 --- a/docs/user/ppl/cmd/spath.md +++ b/docs/user/ppl/cmd/spath.md @@ -1,18 +1,26 @@ - # spath -The `spath` command extracts fields from structured text data by allowing you to select JSON values using JSON paths. +The `spath` command extracts fields from structured JSON data. It supports two modes: + +1. **Path-based extraction**: Extract specific fields using JSON paths +2. **Field resolution-based extraction**: Extract multiple fields automatically based on downstream field requirements > **Note**: The `spath` command is not executed on OpenSearch data nodes. It extracts fields from data after it has been returned to the coordinator node, which is slow on large datasets. We recommend indexing fields needed for filtering directly instead of using `spath` to filter nested fields. ## Syntax -The `spath` command has the following syntax: +### Path-based Extraction ```syntax spath input=<field> [output=<field>] [path=<path>] ``` +### Field Resolution-based Extraction (Experimental) + +```syntax +spath input=<field> +``` + ## Parameters The `spath` command supports the following parameters. @@ -20,11 +28,19 @@ The `spath` command supports the following parameters. | Parameter | Required/Optional | Description | | --- | --- | --- | | `input` | Required | The field containing JSON data to parse. | -| `output` | Optional | The destination field in which the extracted data is stored. Default is the value of `<path>`. | -| `<path>` | Required | The JSON path that identifies the data to extract. | +| `output` | Optional | The destination field in which the extracted data is stored. Default is the value of `<path>`. Only used in path-based extraction. | +| `<path>` | Required for path-based extraction | The JSON path that identifies the data to extract. | For more information about path syntax, see [json_extract](../functions/json.md#json_extract). +### Field Resolution-based Extraction Notes + +* Extracts only the required fields based on downstream command requirements (an interim solution until full field extraction is implemented) +* **Limitation**: It raises an error if the extracted fields cannot be identified by the commands that follow (i.e., a `fields` or `stats` command is needed) +* **Limitation**: Wildcards (`*`) cannot be used in field selection; only explicit field names are supported +* **Limitation**: All extracted fields are returned as STRING type +* **Limitation**: Filtering with a subquery (`where in/exists [...]`) is not supported after the `spath` command + ## Example 1: Basic field extraction The basic use of `spath` extracts a single field from JSON data. The following query extracts the `n` field from JSON objects in the `doc_n` field: @@ -122,4 +138,121 @@ fetched rows / total rows = 3/3 | false | 2 | +-------+---+ ``` - + +## Example 5: Field Resolution-based Extraction + +Extract multiple fields automatically based on downstream requirements. The `spath` command analyzes which fields are needed and extracts only those fields. + +```ppl +source=structured +| eval c = 1 +| spath input=doc_multi +| fields doc_multi, a, b, c +``` + +Expected output: + +```text +fetched rows / total rows = 3/3 ++--------------------------------------+----+----+--------+ +| doc_multi | a | b | c | +|--------------------------------------+----+----+--------| +| {"a": 10, "b": 20, "c": 30, "d": 40} | 10 | 20 | [1,30] | +| {"a": 15, "b": 25, "c": 35, "d": 45} | 15 | 25 | [1,35] | +| {"a": 11, "b": 21, "c": 31, "d": 41} | 11 | 21 | [1,31] | ++--------------------------------------+----+----+--------+ +``` + +This extracts only fields `a`, `b`, and `c` from the JSON in the `doc_multi` field, even though the JSON also contains field `d`. All extracted fields are returned as STRING type. As shown for `c` in this example, if an extracted field already exists, the extracted value is appended to the existing value to form an array. + +## Example 6: Field Merge with Dotted Names + +When a JSON document contains both a direct field with a dotted name and a nested object path that resolves to the same field name, `spath` merges both values into an array. + +```ppl +source=structured +| spath input=doc_dotted +| fields doc_dotted, a.b +| head 1 +``` + +Expected output: + +```text +fetched rows / total rows = 1/1 ++---------------------------+--------+ +| doc_dotted | a.b | +|---------------------------+--------| +| {"a.b": 1, "a": {"b": 2}} | [1, 2] | ++---------------------------+--------+ +``` + +In this example, the JSON contains both `"a.b": 1` (direct field with dot) and `"a": {"b": 2}` (nested path). The `spath` command extracts both values and merges them into the array `[1, 2]`. + +## Example 7: Field Resolution with Eval + +This example shows field resolution with computed fields. The `spath` command extracts only the fields needed by downstream commands. + +```ppl +source=structured +| spath input=doc_multi +| eval sum_ab = cast(a as int) + cast(b as int) +| fields doc_multi, a, b, sum_ab +``` + +Expected output: + +```text +fetched rows / total rows = 3/3 ++--------------------------------------+----+----+--------+ +| doc_multi | a | b | sum_ab | +|--------------------------------------+----+----+--------| +| {"a": 10, "b": 20, "c": 30, "d": 40} | 10 | 20 | 30 | +| {"a": 15, "b": 25, "c": 35, "d": 45} | 15 | 25 | 40 | +| {"a": 11, "b": 21, "c": 31, "d": 41} | 11 | 21 | 32 | ++--------------------------------------+----+----+--------+ +``` + +The `spath` command extracts only fields `a` and `b` (needed by the `eval` command), which are then cast to integers and summed. Fields `c` and `d` are not extracted since they're not needed. + +## Example 8: Field Resolution with Stats + +This example demonstrates field resolution with aggregation.
The `spath` command extracts only the fields needed for grouping and aggregation. + +```ppl +source=structured +| spath input=doc_multi +| stats avg(cast(a as int)) as avg_a, sum(cast(b as int)) as sum_b by c +``` + +Expected output: + +```text +fetched rows / total rows = 3/3 ++-------+-------+----+ +| avg_a | sum_b | c | +|-------+-------+----| +| 10.0 | 20 | 30 | +| 11.0 | 21 | 31 | +| 15.0 | 25 | 35 | ++-------+-------+----+ +``` + +The `spath` command extracts fields `a`, `b`, and `c` (needed by the `stats` command for aggregation and grouping). Field `d` is not extracted since it's not used. + +## Example 9: Field Resolution Limitations + +**Important**: It raises an error if the extracted fields cannot be identified by the commands that follow: + +```ppl +source=structured +| spath input=doc_multi +| eval x = a * b # ERROR: Requires field selection (fields or stats command) +``` + +**Important**: Wildcards are not supported in field resolution mode: + +```ppl +source=structured +| spath input=doc_multi +| fields a, b* # ERROR: Spath command cannot extract arbitrary fields +```
diff --git a/doctest/test_data/structured.json b/doctest/test_data/structured.json index c0717c6f328..a44c9d95cd4 100644 --- a/doctest/test_data/structured.json +++ b/doctest/test_data/structured.json @@ -1,3 +1,3 @@ -{"doc_n":"{\"n\": 1}","doc_escape":"{\"a fancy field name\": true,\"a.b.c\": 0}","doc_list":"{\"list\": [1, 2, 3, 4], \"nest_out\": {\"nest_in\": \"a\"}}","obj_field":{"field": "a"}} -{"doc_n":"{\"n\": 2}","doc_escape":"{\"a fancy field name\": true,\"a.b.c\": 1}","doc_list":"{\"list\": [], \"nest_out\": {\"nest_in\": \"a\"}}","obj_field":{"field": "b"}} -{"doc_n":"{\"n\": 3}","doc_escape":"{\"a fancy field name\": false,\"a.b.c\": 2}","doc_list":"{\"list\": [5, 6], \"nest_out\": {\"nest_in\": \"a\"}}","obj_field":{"field": "c"}} \ No newline at end of file +{"doc_n":"{\"n\": 1}","doc_escape":"{\"a fancy field name\": true,\"a.b.c\": 0}","doc_list":"{\"list\": [1, 2, 3, 4], \"nest_out\": {\"nest_in\": \"a\"}}","doc_multi":"{\"a\": 10, \"b\": 20, \"c\": 30, \"d\": 40}","doc_dotted":"{\"a.b\": 1, \"a\": {\"b\": 2}}","obj_field":{"field": "a"}} +{"doc_n":"{\"n\": 2}","doc_escape":"{\"a fancy field name\": true,\"a.b.c\": 1}","doc_list":"{\"list\": [], \"nest_out\": {\"nest_in\": \"a\"}}","doc_multi":"{\"a\": 15, \"b\": 25, \"c\": 35, \"d\": 45}","doc_dotted":"{\"a.b\": 1, \"a\": {\"b\": 2}}","obj_field":{"field": "b"}} +{"doc_n":"{\"n\": 3}","doc_escape":"{\"a fancy field name\": false,\"a.b.c\": 2}","doc_list":"{\"list\": [5, 6], \"nest_out\": {\"nest_in\": \"a\"}}","doc_multi":"{\"a\": 11, \"b\": 21, \"c\": 31, \"d\": 41}","doc_dotted":"{\"a.b\": 1, \"a\": {\"b\": 2}}","obj_field":{"field": "c"}}
diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java index c254fb47c44..22a6f6b5916 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java @@ -84,6 +84,7 @@ CalcitePPLRenameIT.class, CalcitePPLScalarSubqueryIT.class, CalcitePPLSortIT.class, + CalcitePPLSpathCommandIT.class, CalcitePPLStringBuiltinFunctionIT.class, CalcitePPLTrendlineIT.class, CalcitePrometheusDataSourceCommandsIT.class,
diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index b6cd327989a..95db518ceb7
100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -2321,6 +2321,13 @@ public void testNotBetweenPushDownExplain() throws Exception { "source=opensearch-sql_test_index_bank | where age not between 30 and 39")); } + @Test + public void testSpathWithoutPathExplain() throws IOException { + String expected = loadExpectedPlan("explain_spath_without_path.yaml"); + assertYamlEqualsIgnoreId( + expected, explainQueryYaml(source(TEST_INDEX_LOGS, "spath input=message | fields test"))); + } + @Test public void testExplainInVariousModeAndFormat() throws IOException { enabledOnlyWhenPushdownIsEnabled(); diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLSpathCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLSpathCommandIT.java index 51b5bd40304..6431590b096 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLSpathCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLSpathCommandIT.java @@ -5,6 +5,7 @@ package org.opensearch.sql.calcite.remote; +import static org.opensearch.sql.util.MatcherUtils.array; import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.schema; import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; @@ -17,32 +18,155 @@ import org.opensearch.sql.ppl.PPLIntegTestCase; public class CalcitePPLSpathCommandIT extends PPLIntegTestCase { + private static final String INDEX = "test_spath"; + @Override public void init() throws Exception { super.init(); enableCalcite(); - loadIndex(Index.BANK); + putItem(1, "simple", sj("{'a': 1, 'b': 2, 'c': 3}")); + putItem(2, "simple", sj("{'a': 1, 'b': 2, 'c': 3}")); + putItem(3, "nested", sj("{'nested': {'d': [1, 2, 3], 'e': 'str'}}")); + putItem(4, "join1", sj("{'key': 'k1', 'left': 'l'}")); + putItem(5, "join2", sj("{'key': 'k1', 'right': 'r1'}")); + putItem(6, "join2", sj("{'key': 'k2', 'right': 'r2'}")); + putItem(7, "overwrap", sj("{'a.b': 1, 'a': {'b': 2, 'c': 3}}")); + } - // Create test data for string concatenation - Request request1 = new Request("PUT", "/test_spath/_doc/1?refresh=true"); - request1.setJsonEntity("{\"doc\": \"{\\\"n\\\": 1}\"}"); - client().performRequest(request1); + private void putItem(int id, String testCase, String json) throws Exception { + Request request = new Request("PUT", String.format("/%s/_doc/%d?refresh=true", INDEX, id)); + request.setJsonEntity(docWithJson(testCase, json)); + client().performRequest(request); + } - Request request2 = new Request("PUT", "/test_spath/_doc/2?refresh=true"); - request2.setJsonEntity("{\"doc\": \"{\\\"n\\\": 2}\"}"); - client().performRequest(request2); + private String docWithJson(String testCase, String json) { + return String.format(sj("{'testCase': '%s', 'doc': '%s'}"), testCase, escape(json)); + } - Request request3 = new Request("PUT", "/test_spath/_doc/3?refresh=true"); - request3.setJsonEntity("{\"doc\": \"{\\\"n\\\": 3}\"}"); - client().performRequest(request3); + private String escape(String json) { + return json.replace("\"", "\\\""); + } + + private String sj(String singleQuoteJson) { + return singleQuoteJson.replace("'", "\""); } @Test public void testSimpleSpath() throws IOException { JSONObject result = - executeQuery("source=test_spath | spath input=doc output=result path=n | fields result"); + executeQuery( + "source=test_spath | where 
testCase='simple' | spath input=doc output=result path=a |" + + " fields result | head 2"); verifySchema(result, schema("result", "string")); - verifyDataRows(result, rows("1"), rows("2"), rows("3")); + verifyDataRows(result, rows("1"), rows("1")); + } + + private static final String EXPECTED_ARBITRARY_FIELD_ERROR = + "Spath command cannot extract arbitrary fields. " + + "Please project fields explicitly by fields command without wildcard or stats command."; + + @Test + public void testSpathWithoutFields() throws IOException { + verifyExplainException( + "source=test_spath | spath input=doc | eval a = 1", EXPECTED_ARBITRARY_FIELD_ERROR); + } + + @Test + public void testSpathWithWildcard() throws IOException { + verifyExplainException( + "source=test_spath | spath input=doc | fields a, b*", EXPECTED_ARBITRARY_FIELD_ERROR); + } + + private static final String EXPECTED_SUBQUERY_ERROR = + "Filter by subquery is not supported with field resolution."; + + @Test + public void testSpathWithSubsearch() throws IOException { + verifyExplainException( + "source=test_spath | spath input=doc | where b in [source=test_spath | fields a] | fields" + + " b", + EXPECTED_SUBQUERY_ERROR); + } + + @Test + public void testSpathWithFields() throws IOException { + JSONObject result = + executeQuery( + "source=test_spath | where testCase='simple' | spath input=doc | fields a, b, c | head" + + " 1"); + verifySchema(result, schema("a", "string"), schema("b", "string"), schema("c", "string")); + verifyDataRows(result, rows("1", "2", "3")); + } + + @Test + public void testSpathWithAbsentField() throws IOException { + JSONObject result = + executeQuery( + "source=test_spath | where testCase='simple' | spath input=doc | fields a, x | head 1"); + verifySchema(result, schema("a", "string"), schema("x", "string")); + verifyDataRows(result, rows("1", null)); + } + + @Test + public void testOverwrap() throws IOException { + JSONObject result = + executeQuery( + "source=test_spath | where testCase='overwrap' | spath input=doc | fields a.b | head" + + " 1"); + verifySchema(result, schema("a.b", "string")); + verifyDataRows(result, rows("[1, 2]")); + } + + @Test + public void testSpathTwice() throws IOException { + JSONObject result = + executeQuery( + "source=test_spath | where testCase='simple' | spath input=doc | spath input=doc |" + + " fields a, doc | head 1"); + verifySchema(result, schema("a", "array"), schema("doc", "string")); + verifyDataRows(result, rows(array("1", "1"), sj("{'a': 1, 'b': 2, 'c': 3}"))); + } + + @Test + public void testSpathWithEval() throws IOException { + JSONObject result = + executeQuery( + "source=test_spath | where testCase='simple' | spath input=doc |" + + " eval result = a * b * c | fields result | head 1"); + verifySchema(result, schema("result", "double")); + verifyDataRows(result, rows(6)); + } + + @Test + public void testSpathWithStats() throws IOException { + JSONObject result = + executeQuery( + "source=test_spath | where testCase='simple' | spath input=doc |" + + "stats count by a, b | head 1"); + verifySchema(result, schema("count", "bigint"), schema("a", "string"), schema("b", "string")); + verifyDataRows(result, rows(2, "1", "2")); + } + + @Test + public void testSpathWithNestedFields() throws IOException { + JSONObject result = + executeQuery( + "source=test_spath | where testCase='nested' | spath input=doc | fields `nested.d{}`," + + " nested.e"); + verifySchema(result, schema("nested.d{}", "string"), schema("nested.e", "string")); + verifyDataRows(result, rows("[1, 2, 3]", 
"str")); + } + + @Test + public void testSpathWithJoin() throws IOException { + JSONObject result = + executeQuery( + "source=test_spath | where testCase='join1' | spath input=doc | fields key, left | join" + + " key [source=test_spath | where testCase='join2' | spath input=doc | fields key," + + " right ] |fields key, left, right"); + verifySchema( + result, schema("key", "string"), schema("left", "string"), schema("right", "string")); + verifyDataRows(result, rows("k1", "l", "r1")); } } diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/JsonExtractAllFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/JsonExtractAllFunctionIT.java index 68bf57ea8dd..f2683e91e10 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/JsonExtractAllFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/JsonExtractAllFunctionIT.java @@ -246,8 +246,7 @@ public void testJsonExtractAllWithEmptyObject() throws Exception { assertTrue(resultSet.next()); verifyColumns(resultSet, RESULT_FIELD); - Map map = getMap(resultSet, 1); - assertTrue(map.isEmpty()); + assertNull(resultSet.getObject(1)); }); } diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/MapAppendFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/MapAppendFunctionIT.java index a77c3270e3b..e231b8ae68f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/MapAppendFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/MapAppendFunctionIT.java @@ -5,6 +5,8 @@ package org.opensearch.sql.calcite.standalone; +import static org.hamcrest.Matchers.*; + import java.sql.ResultSet; import java.sql.SQLException; import java.util.List; @@ -41,11 +43,14 @@ public void testMapAppendWithNonOverlappingKeys() throws Exception { relNode, resultSet -> { Map result = getResultMapField(resultSet); - assertEquals(4, result.size()); - assertMapListValue(result, "key1", "value1"); - assertMapListValue(result, "key2", "value2"); - assertMapListValue(result, "key3", "value3"); - assertMapListValue(result, "key4", "value4"); + assertThat( + result, + allOf( + hasEntry("key1", "value1"), + hasEntry("key2", "value2"), + hasEntry("key3", "value3"), + hasEntry("key4", "value4"), + aMapWithSize(4))); }); } @@ -68,10 +73,13 @@ public void testMapAppendWithOverlappingKeys() throws Exception { relNode, resultSet -> { Map result = getResultMapField(resultSet); - assertEquals(3, result.size()); - assertMapListValue(result, "key1", "value1"); - assertMapListValue(result, "key2", "value2", "value3"); - assertMapListValue(result, "key3", "value4"); + assertThat( + result, + allOf( + aMapWithSize(3), + hasEntry("key1", (Object) "value1"), + hasEntry("key2", (Object) List.of("value2", "value3")), + hasEntry("key3", (Object) "value4"))); }); } @@ -100,8 +108,7 @@ private void testWithSingleNull(RexNode map1, RexNode map2) throws Exception { relNode, resultSet -> { Map result = getResultMapField(resultSet); - assertEquals(1, result.size()); - assertMapListValue(result, "key1", "value1"); + assertThat(result, allOf(hasEntry("key1", "value1"), aMapWithSize(1))); }); } @@ -143,16 +150,4 @@ private Map getResultMapField(ResultSet resultSet) throws SQLExc Map result = (Map) resultSet.getObject(1); return result; } - - @SuppressWarnings("unchecked") - private void assertMapListValue(Map map, String key, Object... 
expectedValues) { - map.containsKey(key); - Object value = map.get(key); - assertTrue(value instanceof List); - List list = (List) value; - assertEquals(expectedValues.length, list.size()); - for (int i = 0; i < expectedValues.length; i++) { - assertEquals(expectedValues[i], list.get(i)); - } - } } diff --git a/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java b/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java index e41c85a8e80..aea7d1a7d99 100644 --- a/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java +++ b/integ-test/src/test/java/org/opensearch/sql/util/MatcherUtils.java @@ -301,6 +301,10 @@ protected boolean matchesSafely(JSONArray array) { }; } + public static JSONArray array(Object... objects) { + return new JSONArray(objects); + } + public static TypeSafeMatcher closeTo(Object... values) { final double error = 1e-10; return new TypeSafeMatcher() { diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_spath_without_path.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_spath_without_path.yaml new file mode 100644 index 00000000000..36ced78d6db --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_spath_without_path.yaml @@ -0,0 +1,8 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(test=[CAST(ITEM(JSON_EXTRACT_ALL($3), 'test')):VARCHAR NOT NULL]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_logs]]) + physical: | + EnumerableCalc(expr#0=[{inputs}], expr#1=[JSON_EXTRACT_ALL($t0)], expr#2=['test'], expr#3=[ITEM($t1, $t2)], expr#4=[CAST($t3):VARCHAR NOT NULL], test=[$t4]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_logs]], PushDownContext=[[PROJECT->[message], LIMIT->10000], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":10000,"timeout":"1m","_source":{"includes":["message"],"excludes":[]}}, requestedTotalSize=10000, pageSize=null, startFrom=0)]) diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_spath_without_path.yaml b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_spath_without_path.yaml new file mode 100644 index 00000000000..0f47f3015d2 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_spath_without_path.yaml @@ -0,0 +1,9 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(test=[CAST(ITEM(JSON_EXTRACT_ALL($3), 'test')):VARCHAR NOT NULL]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_logs]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableCalc(expr#0..10=[{inputs}], expr#11=[JSON_EXTRACT_ALL($t3)], expr#12=['test'], expr#13=[ITEM($t11, $t12)], expr#14=[CAST($t13):VARCHAR NOT NULL], test=[$t14]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_logs]]) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java index 93d07a83843..8f3f337ffce 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java @@ -56,7 +56,7 @@ OpenSearchResponse search( /** * Check if there is more data to get from OpenSearch. 
* - * @return True if calling {@ref OpenSearchClient.search} with this request will return non-empty + * @return True if calling {@link OpenSearchClient.search} with this request will return non-empty * response. */ boolean hasAnotherBatch(); diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 3f4f3049365..ffac1159e52 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -905,8 +905,10 @@ public UnresolvedPlan visitSpathCommand(OpenSearchPPLParser.SpathCommandContext if (inField == null) { throw new IllegalArgumentException("`input` parameter is required for `spath`"); } - if (path == null) { - throw new IllegalArgumentException("`path` parameter is required for `spath`"); + + if (outField != null && path == null) { + throw new IllegalArgumentException( + "`path` parameter is required for `spath` when `output` is specified"); } return new SPath(inField, outField, path); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/FieldResolutionVisitorCoverageTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/FieldResolutionVisitorCoverageTest.java new file mode 100644 index 00000000000..f39d80adc37 --- /dev/null +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/FieldResolutionVisitorCoverageTest.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.parser; + +import static org.junit.Assert.assertTrue; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; +import org.junit.Test; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.analysis.FieldResolutionVisitor; +import org.opensearch.sql.calcite.CalciteRelNodeVisitor; + +/** + * Test to verify that FieldResolutionVisitor overrides all methods that CalciteRelNodeVisitor + * overrides from AbstractNodeVisitor. + * + *
<p>
This ensures that FieldResolutionVisitor provides field resolution logic for all AST node + * types that CalciteRelNodeVisitor handles. + */ +public class FieldResolutionVisitorCoverageTest { + + @Test + public void testFieldResolutionVisitorOverridesAllCalciteRelNodeVisitorMethods() { + Set<String> calciteOverriddenMethods = getOverriddenMethods(CalciteRelNodeVisitor.class); + Set<String> fieldResolutionOverriddenMethods = + getOverriddenMethods(FieldResolutionVisitor.class); + + // Find methods that CalciteRelNodeVisitor overrides but FieldResolutionVisitor doesn't + Set<String> missingMethods = new HashSet<>(calciteOverriddenMethods); + missingMethods.removeAll(fieldResolutionOverriddenMethods); + + // Only allow unsupported Calcite commands to be missing + Set<String> unsupportedCalciteCommands = + Set.of( + "visitAD", + "visitCloseCursor", + "visitFetchCursor", + "visitML", + "visitPaginate", + "visitKmeans", + "visitTableFunction"); + + missingMethods.removeAll(unsupportedCalciteCommands); + + assertTrue( + "FieldResolutionVisitor must override all supported methods that CalciteRelNodeVisitor " + + "overrides. Missing methods: " + + missingMethods, + missingMethods.isEmpty()); + } + + private Set<String> getOverriddenMethods(Class<?> clazz) { + Set<String> abstractMethods = getAbstractNodeVisitorMethods(); + return Arrays.stream(clazz.getDeclaredMethods()) + .filter(method -> abstractMethods.contains(method.getName())) + .map(Method::getName) + .collect(Collectors.toSet()); + } + + private Set<String> getAbstractNodeVisitorMethods() { + return Arrays.stream(AbstractNodeVisitor.class.getDeclaredMethods()) + .filter(method -> method.getName().startsWith("visit")) + .map(Method::getName) + .collect(Collectors.toSet()); + } +} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/FieldResolutionVisitorTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/FieldResolutionVisitorTest.java new file mode 100644 index 00000000000..c7d8de43964 --- /dev/null +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/FieldResolutionVisitorTest.java @@ -0,0 +1,355 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.parser; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.when; + +import java.util.Map; +import java.util.Set; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; +import org.opensearch.sql.ast.analysis.FieldResolutionResult; +import org.opensearch.sql.ast.analysis.FieldResolutionVisitor; +import org.opensearch.sql.ast.tree.Relation; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; + +public class FieldResolutionVisitorTest { + + private final FieldResolutionVisitor visitor = new FieldResolutionVisitor(); + private final PPLSyntaxParser parser = new PPLSyntaxParser(); + private Settings settings; + + @Before + public void setUp() { + settings = Mockito.mock(Settings.class); + when(settings.getSettingValue(Settings.Key.PPL_REX_MAX_MATCH_LIMIT)).thenReturn(10); + } + + private UnresolvedPlan parse(String query) { + AstBuilder astBuilder = new AstBuilder(query, settings); + return astBuilder.visit(parser.parse(query)); + } + + private FieldResolutionResult getSingleRelationResult(String query) { + UnresolvedPlan plan = parse(query); + Map<UnresolvedPlan, FieldResolutionResult> results = visitor.analyze(plan); + UnresolvedPlan relation =
results.keySet().iterator().next(); + return results.get(relation); + } + + private void assertSingleRelationFields( + String query, Set<String> expectedFields, String expectedWildcard) { + FieldResolutionResult result = getSingleRelationResult(query); + assertEquals(expectedFields, result.getRegularFields()); + assertEquals(expectedWildcard, result.getWildcard().toString()); + } + + private void assertJoinRelationFields( + String query, Map<String, FieldResolutionResult> expectedResultsByTable) { + UnresolvedPlan plan = parse(query); + Map<UnresolvedPlan, FieldResolutionResult> results = visitor.analyze(plan); + + assertEquals(expectedResultsByTable.size(), results.size()); + + for (Map.Entry<UnresolvedPlan, FieldResolutionResult> entry : results.entrySet()) { + if (!(entry.getKey() instanceof Relation)) { + continue; + } + String tableName = ((Relation) entry.getKey()).getTableQualifiedName().toString(); + FieldResolutionResult expectedResult = expectedResultsByTable.get(tableName); + + if (expectedResult != null) { + assertEquals(expectedResult.getRegularFields(), entry.getValue().getRegularFields()); + assertEquals(expectedResult.getWildcard(), entry.getValue().getWildcard()); + } + } + } + + @Test + public void testSimpleRelation() { + assertSingleRelationFields("source=logs", Set.of(), "*"); + } + + @Test + public void testFilterOnly() { + assertSingleRelationFields("source=logs | where status > 200", Set.of("status"), "*"); + } + + @Test + public void testMultipleFilters() { + assertSingleRelationFields( + "source=logs | where status > 200 AND region = 'us-west'", Set.of("status", "region"), "*"); + } + + @Test + public void testProjectOnly() { + assertSingleRelationFields( + "source=logs | fields status, region", Set.of("status", "region"), ""); + } + + @Test + public void testMultipleProject() { + assertSingleRelationFields( + "source=logs | fields status, region, *age | fields message, st*", + Set.of("status", "message"), + "(st*) & (*age)"); + } + + @Test + public void testFilterThenProject() { + assertSingleRelationFields( + "source=logs | where status > 200 | fields region", Set.of("region", "status"), ""); + } + + @Test + public void testAggregationWithGroupBy() { + assertSingleRelationFields("source=logs | stats count() by region", Set.of("region"), ""); + } + + @Test + public void testAggregationWithFieldAndGroupBy() { + assertSingleRelationFields( + "source=logs | stats avg(response_time) by region", Set.of("region", "response_time"), ""); + } + + @Test + public void testComplexQuery() { + assertSingleRelationFields( + "source=logs | where status > 200 | stats count() by region", + Set.of("region", "status"), + ""); + } + + @Test + public void testSortCommand() { + assertSingleRelationFields("source=logs | sort status", Set.of("status"), "*"); + } + + @Test + public void testEvalCommand() { + assertSingleRelationFields( + "source=logs | eval new_field = old_field + 1", Set.of("old_field"), "*"); + } + + @Test + public void testEvalThenFilter() { + assertSingleRelationFields( + "source=logs | eval doubled = value * 2 | where doubled > 100", Set.of("value"), "*"); + } + + @Test + public void testNestedFields() { + assertSingleRelationFields( + "source=logs | where `user.name` = 'john'", Set.of("user.name"), "*"); + } + + @Test + public void testFunctionInFilter() { + assertSingleRelationFields("source=logs | where length(message) > 100", Set.of("message"), "*"); + } + + @Test + public void testMultipleAggregations() { + assertSingleRelationFields( + "source=logs | stats count(), avg(response_time), max(bytes) by region, status", + Set.of("region", "status", "response_time", "bytes"),
""); + } + + @Test + public void testComplexNestedQuery() { + assertSingleRelationFields( + "source=logs | where status > 200 AND region = 'us-west' " + + "| eval response_ms = response_time * 1000 " + + "| stats avg(response_ms), max(bytes) by region, status " + + "| sort region", + Set.of("status", "region", "response_time", "bytes"), + ""); + } + + @Test + public void testWildcardPatternMerging() { + assertSingleRelationFields( + "source=logs | fields `prefix*`, `prefix_sub*`", Set.of(), "prefix* | prefix_sub*"); + } + + @Test + public void testSingleWildcardPattern() { + assertSingleRelationFields("source=logs | fields `prefix*`", Set.of(), "prefix*"); + } + + @Test + public void testWildcardWithRegularFields() { + assertSingleRelationFields( + "source=logs | fields status, `prefix*`, region", Set.of("status", "region"), "prefix*"); + } + + @Test + public void testMultiRelationResult() { + UnresolvedPlan plan = parse("source=logs | where status > 200"); + Map<UnresolvedPlan, FieldResolutionResult> results = visitor.analyze(plan); + + assertEquals(1, results.size()); + + UnresolvedPlan relation = results.keySet().iterator().next(); + assertTrue(relation instanceof Relation); + assertEquals("logs", ((Relation) relation).getTableQualifiedName().toString()); + + FieldResolutionResult result = results.get(relation); + assertEquals(Set.of("status"), result.getRegularFields()); + assertEquals("*", result.getWildcard().toString()); + } + + @Test + public void testSimpleJoin() { + assertJoinRelationFields( + "source=logs1 | join left=l right=r ON l.id = r.id logs2", + Map.of( + "logs1", new FieldResolutionResult(Set.of("id"), "*"), + "logs2", new FieldResolutionResult(Set.of("id"), "*"))); + } + + @Test + public void testJoinWithFilter() { + assertJoinRelationFields( + "source=logs1 | where status > 200 | join left=l right=r ON l.id = r.id logs2", + Map.of( + "logs1", new FieldResolutionResult(Set.of("status", "id"), "*"), + "logs2", new FieldResolutionResult(Set.of("id"), "*"))); + } + + @Test + public void testJoinWithProject() { + assertJoinRelationFields( + "source=logs1 | join left=l right=r ON l.id = r.id logs2 | fields l.name, r.value", + Map.of( + "logs1", new FieldResolutionResult(Set.of("name", "id")), + "logs2", new FieldResolutionResult(Set.of("value", "id")))); + } + + @Test + public void testJoinWithNestedFields() { + assertJoinRelationFields( + "source=logs1 | join left=l right=r ON l.id = r.id logs2 | fields l.name, r.value, field," + + " nested.field", + Map.of( + "logs1", new FieldResolutionResult(Set.of("name", "id", "field", "nested.field")), + "logs2", new FieldResolutionResult(Set.of("value", "id", "field", "nested.field")))); + } + + @Test + public void testSelfJoin() { + UnresolvedPlan plan = + parse("source=logs | fields id | join left=l right=r ON l.id = r.parent_id logs"); + Map<UnresolvedPlan, FieldResolutionResult> results = visitor.analyze(plan); + + assertEquals(2, results.size()); + + for (Map.Entry<UnresolvedPlan, FieldResolutionResult> entry : results.entrySet()) { + assertTrue(entry.getKey() instanceof Relation); + Relation relation = (Relation) entry.getKey(); + FieldResolutionResult result = entry.getValue(); + String tableName = relation.getTableQualifiedName().toString(); + + assertEquals("logs", tableName); + Set<String> fields = result.getRegularFields(); + assertEquals(1, fields.size()); + if (fields.contains("id")) { + assertEquals("", result.getWildcard().toString()); + } else { + assertTrue(fields.contains("parent_id")); + assertEquals("*", result.getWildcard().toString()); + } + } + } + + @Test + public void testJoinWithAggregation() { + assertJoinRelationFields(
"source=logs1 | join left=l right=r ON l.id = r.id logs2 | stats count() by l.region", + Map.of( + "logs1", new FieldResolutionResult(Set.of("region", "id")), + "logs2", new FieldResolutionResult(Set.of("id")))); + } + + @Test + public void testJoinWithSubsearch() { + assertJoinRelationFields( + "source=idx1 | where b > 1 | join a [source=idx2 | where c > 2 ] | eval result = c * d", + Map.of( + "idx1", new FieldResolutionResult(Set.of("a", "b", "c", "d"), "*"), + "idx2", new FieldResolutionResult(Set.of("a", "c", "d"), "*"))); + } + + @Test + public void testWhereWithSubsearch() { + assertThrows( + "Filter by subquery is not supported with field resolution.", + IllegalArgumentException.class, + () -> + visitor.analyze( + parse( + "source=idx1 | where b in [source=idx2 | where a > 2 | fields b] | fields" + + " c, d"))); + } + + @Test + public void testRegexCommand() { + assertSingleRelationFields("source=logs | regex status='error.*'", Set.of("status"), "*"); + } + + @Test + public void testRexCommand() { + assertSingleRelationFields( + "source=logs | rex field=message \"(?<user>[^@]+)@(?<domain>.+)\" | fields user, domain," + + " other*", + Set.of("message"), + "other*"); + } + + @Test + public void testPatternsCommand() { + assertSingleRelationFields( + "source=logs | patterns message by other method=brain mode=aggregation" + + " show_numbered_token=true | fields patterns_field, pattern_count, tokens, other*", + Set.of("message", "other"), + "other*"); + } + + @Test + public void testDedupeCommand() { + assertSingleRelationFields("source=logs | dedup host, status", Set.of("host", "status"), "*"); + } + + @Test + public void testReverseCommand() { + assertSingleRelationFields("source=logs | reverse", Set.of(), "*"); + } + + @Test + public void testHeadCommand() { + assertSingleRelationFields("source=logs | head 10", Set.of(), "*"); + } + + @Test + public void testRenameCommand() { + assertSingleRelationFields("source=logs | rename old_name as new_name", Set.of(), "*"); + } + + @Test + public void testUnimplementedVisitDetected() { + assertThrows( + "Unsupported command for field resolution: Kmeans", + IllegalArgumentException.class, + () -> visitor.analyze(parse("source=idx1 | kmeans centroids=3"))); + } +}