diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 3bdd317af5..ef0bdf309c 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -22,6 +22,7 @@ import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.Interval; +import org.opensearch.sql.ast.expression.LambdaFunction; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Map; @@ -231,6 +232,10 @@ public T visitSort(Sort node, C context) { return visitChildren(node, context); } + public T visitLambdaFunction(LambdaFunction node, C context) { + return visitChildren(node, context); + } + public T visitDedupe(Dedupe node, C context) { return visitChildren(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java new file mode 100644 index 0000000000..e7eab42765 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/expression/LambdaFunction.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.expression; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** + * Expression node of lambda function. Params include function name (@funcName) and function + * arguments (@funcArgs) + */ +@Getter +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +public class LambdaFunction extends UnresolvedExpression { + private final UnresolvedExpression function; + private final List funcArgs; + + @Override + public List getChild() { + List children = new ArrayList<>(); + children.add(function); + children.addAll(funcArgs); + return children; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitLambdaFunction(this, context); + } + + @Override + public String toString() { + return String.format( + "(%s) -> %s", + funcArgs.stream().map(Object::toString).collect(Collectors.joining(", ")), + function.toString()); + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java b/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java index 189db7d03b..a13266bf49 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java @@ -8,12 +8,15 @@ import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY; import java.sql.Connection; +import java.util.HashMap; +import java.util.Map; import java.util.Optional; import java.util.Stack; import java.util.function.BiFunction; import lombok.Getter; import lombok.Setter; import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.tools.FrameworkConfig; import org.apache.calcite.tools.RelBuilder; @@ -44,6 +47,8 @@ public class CalcitePlanContext { private final Stack correlVar = new Stack<>(); + @Getter public Map temparolInputMap; + private CalcitePlanContext(FrameworkConfig config, QueryType queryType) { this.config = config; this.queryType = queryType; @@ -51,6 +56,7 @@ private CalcitePlanContext(FrameworkConfig config, QueryType queryType) { this.relBuilder = CalciteToolsHelper.create(config, TYPE_FACTORY, connection); this.rexBuilder = new ExtendedRexBuilder(relBuilder.getRexBuilder()); this.functionProperties = new FunctionProperties(QueryType.PPL); + this.temparolInputMap = new HashMap<>(); } public RexNode resolveJoinCondition( @@ -82,7 +88,15 @@ public Optional peekCorrelVar() { } } + public CalcitePlanContext clone() { + return new CalcitePlanContext(config, queryType); + } + public static CalcitePlanContext create(FrameworkConfig config, QueryType queryType) { return new CalcitePlanContext(config, queryType); } + + public void putTemparolInputmapAll(Map candidateMap) { + this.temparolInputMap.putAll(candidateMap); + } } diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java index 7c6c4f1724..14de48ca42 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java @@ -8,21 +8,29 @@ import static org.opensearch.sql.ast.expression.SpanUnit.NONE; import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN; import static org.opensearch.sql.calcite.utils.BuiltinFunctionUtils.VARCHAR_FORCE_NULLABLE; +import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY; import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedFunction; import java.math.BigDecimal; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Locale; +import java.util.Map; import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; import lombok.RequiredArgsConstructor; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlIntervalQualifier; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.ArraySqlType; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.DateString; import org.apache.calcite.util.TimeString; @@ -39,6 +47,7 @@ import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.Interval; +import org.opensearch.sql.ast.expression.LambdaFunction; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Not; @@ -284,6 +293,9 @@ public RexNode visitQualifiedName(QualifiedName node, CalcitePlanContext context // TODO: Need to support nested fields https://github.com/opensearch-project/sql/issues/3459 // 2. resolve QualifiedName in non-join condition String qualifiedName = node.toString(); + if (context.getTemparolInputMap().containsKey(qualifiedName)) { + return context.getTemparolInputMap().get(qualifiedName); + } List currentFields = context.relBuilder.peek().getRowType().getFieldNames(); if (currentFields.contains(qualifiedName)) { // 2.1 resolve QualifiedName from stack top @@ -337,16 +349,112 @@ private boolean isTimeBased(SpanUnit unit) { return !(unit == NONE || unit == UNKNOWN); } + @Override + public RexNode visitLambdaFunction(LambdaFunction node, CalcitePlanContext context) { + try { + List names = node.getFuncArgs(); + List args = + IntStream.range(0, names.size()) + .mapToObj( + i -> + context.temparolInputMap.getOrDefault( + names.get(i).toString(), + new RexLambdaRef( + i, + names.get(i).toString(), + TYPE_FACTORY.createSqlType(SqlTypeName.ANY)))) + .collect(Collectors.toList()); + RexNode body = node.getFunction().accept(this, context); + RexNode lambdaNode = context.rexBuilder.makeLambdaCall(body, args); + return lambdaNode; + } catch (Exception e) { + throw new RuntimeException("Cannot create lambda function", e); + } + } + @Override public RexNode visitLet(Let node, CalcitePlanContext context) { RexNode expr = analyze(node.getExpression(), context); return context.relBuilder.alias(expr, node.getVar().getField().toString()); } + /** + * The function will clone a context for lambda function. For lambda like (x, y, z) -> ..., we + * will map type for each lambda argument by the order of previous argument. Also, the function + * will add these variables to the context so they can pass visitQualifiedName + */ + private CalcitePlanContext prepareLambdaContext( + CalcitePlanContext context, + LambdaFunction node, + List previousArgument, + String functionName) { + try { + CalcitePlanContext lambdaContext = context.clone(); + List candidateType = new ArrayList<>(); + candidateType.add( + ((ArraySqlType) previousArgument.get(0).getType()) + .getComponentType()); // The first argument should be array type + candidateType.addAll(previousArgument.stream().skip(1).map(RexNode::getType).toList()); + candidateType = modifyLambdaTypeByFunction(functionName, candidateType); + List argNames = node.getFuncArgs(); + Map lambdaTypes = new HashMap<>(); + int candidateIndex; + candidateIndex = 0; + for (int i = 0; i < argNames.size(); i++) { + RelDataType type; + if (candidateIndex < candidateType.size()) { + type = candidateType.get(candidateIndex); + candidateIndex++; + } else { + type = + TYPE_FACTORY.createSqlType( + SqlTypeName.INTEGER); // For transform function, the i is missing in input. + } + lambdaTypes.put( + argNames.get(i).toString(), new RexLambdaRef(i, argNames.get(i).toString(), type)); + } + lambdaContext.putTemparolInputmapAll(lambdaTypes); + return lambdaContext; + } catch (Exception e) { + throw new RuntimeException("Fail to prepare lambda context", e); + } + } + + /** + * @param functionName function name + * @param originalType the argument type by order + * @return a modified types. Different functions need to implement its own order. Currently, only + * reduce has special logic. + */ + private List modifyLambdaTypeByFunction( + String functionName, List originalType) { + switch (functionName.toUpperCase(Locale.ROOT)) { + case "REDUCE": // For reduce case, the first type is acc should be any + return Stream.concat( + Stream.of(TYPE_FACTORY.createSqlType(SqlTypeName.ANY)), + originalType.subList(0, originalType.size() - 1).stream()) + .collect(Collectors.toList()); + default: + return originalType; + } + } + @Override public RexNode visitFunction(Function node, CalcitePlanContext context) { - List arguments = - node.getFuncArgs().stream().map(arg -> analyze(arg, context)).collect(Collectors.toList()); + List args = node.getFuncArgs(); + List arguments = new ArrayList<>(); + for (UnresolvedExpression arg : args) { + if (arg instanceof LambdaFunction) { + CalcitePlanContext lambdaContext = + prepareLambdaContext(context, (LambdaFunction) arg, arguments, node.getFuncName()); + arguments.add(analyze(arg, lambdaContext)); + } else { + arguments.add(analyze(arg, context)); + } + } + // List arguments = + // node.getFuncArgs().stream().map(arg -> analyze(arg, + // context)).collect(Collectors.toList()); RexNode resolvedNode = PPLFuncImpTable.INSTANCE.resolveSafe( context.rexBuilder, node.getFuncName(), arguments.toArray(new RexNode[0])); diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java index 176de3474a..026fea6c7c 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java @@ -56,10 +56,8 @@ public class UserDefinedFunctionUtils { public static RelDataType nullableDateUDT = TYPE_FACTORY.createUDT(EXPR_DATE, true); public static RelDataType nullableTimestampUDT = TYPE_FACTORY.createUDT(ExprUDT.EXPR_TIMESTAMP, true); - public static SqlReturnTypeInference timestampInference = ReturnTypes.explicit(nullableTimestampUDT); - public static SqlReturnTypeInference timeInference = ReturnTypes.explicit(nullableTimeUDT); public static SqlReturnTypeInference dateInference = ReturnTypes.explicit(nullableDateUDT); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 432eb09d49..9cdbc3067e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -56,6 +56,14 @@ public enum BuiltinFunctionName { TAN(FunctionName.of("tan")), SPAN(FunctionName.of("span")), + /** Collection functions */ + ARRAY(FunctionName.of("array")), + FORALL(FunctionName.of("forall")), + EXISTS(FunctionName.of("exists")), + FILTER(FunctionName.of("filter")), + TRANSFORM(FunctionName.of("transform")), + REDUCE(FunctionName.of("reduce")), + /** Date and Time Functions. */ ADDDATE(FunctionName.of("adddate")), ADDTIME(FunctionName.of("addtime")), diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ArrayFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ArrayFunctionImpl.java new file mode 100644 index 0000000000..aa4501c33c --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ArrayFunctionImpl.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function.CollectionUDF; + +import static org.apache.calcite.sql.type.SqlTypeUtil.createArrayType; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.calcite.adapter.enumerable.NotNullImplementor; +import org.apache.calcite.adapter.enumerable.NullPolicy; +import org.apache.calcite.adapter.enumerable.RexToLixTranslator; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.linq4j.tree.Types; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.opensearch.sql.expression.function.ImplementorUDF; + +public class ArrayFunctionImpl extends ImplementorUDF { + public ArrayFunctionImpl() { + super(new ArrayImplementor(), NullPolicy.ANY); + } + + @Override + public SqlReturnTypeInference getReturnTypeInference() { + return sqlOperatorBinding -> { + RelDataTypeFactory typeFactory = sqlOperatorBinding.getTypeFactory(); + List argTypes = sqlOperatorBinding.collectOperandTypes(); + RelDataType commonType = typeFactory.leastRestrictive(argTypes); + if (commonType == null) { + throw new IllegalArgumentException( + "All arguments in json array cannot be converted into one common types"); + } + return createArrayType( + typeFactory, typeFactory.createTypeWithNullability(commonType, true), true); + }; + } + + public static class ArrayImplementor implements NotNullImplementor { + @Override + public Expression implement( + RexToLixTranslator translator, RexCall call, List translatedOperands) { + RelDataType realType = call.getType().getComponentType(); + List newArgs = new ArrayList<>(translatedOperands); + assert realType != null; + newArgs.add(Expressions.constant(realType.getSqlTypeName())); + return Expressions.call( + Types.lookupMethod(ArrayFunctionImpl.class, "eval", Object[].class), newArgs); + } + } + + public static Object eval(Object... args) { + SqlTypeName targetType = (SqlTypeName) args[args.length - 1]; + switch (targetType) { + case DOUBLE: + List unboxed = + IntStream.range(0, args.length - 1) + .mapToObj(i -> ((Number) args[i]).doubleValue()) + .collect(Collectors.toList()); + + return unboxed; + case FLOAT: + List unboxedFloat = + IntStream.range(0, args.length - 1) + .mapToObj(i -> ((Number) args[i]).floatValue()) + .collect(Collectors.toList()); + return unboxedFloat; + default: + return Arrays.asList(args).subList(0, args.length - 1); + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ExistsFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ExistsFunctionImpl.java new file mode 100644 index 0000000000..7508514af5 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ExistsFunctionImpl.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function.CollectionUDF; + +import java.util.List; +import org.apache.calcite.adapter.enumerable.NotNullImplementor; +import org.apache.calcite.adapter.enumerable.NullPolicy; +import org.apache.calcite.adapter.enumerable.RexImpTable; +import org.apache.calcite.adapter.enumerable.RexToLixTranslator; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Types; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.schema.impl.ScalarFunctionImpl; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.opensearch.sql.expression.function.ImplementorUDF; + +public class ExistsFunctionImpl extends ImplementorUDF { + public ExistsFunctionImpl() { + super(new ExistsImplementor(), NullPolicy.ALL); + } + + @Override + public SqlReturnTypeInference getReturnTypeInference() { + return ReturnTypes.BOOLEAN; + } + + public static class ExistsImplementor implements NotNullImplementor { + @Override + public Expression implement( + RexToLixTranslator translator, RexCall call, List translatedOperands) { + ScalarFunctionImpl function = + (ScalarFunctionImpl) + ScalarFunctionImpl.create( + Types.lookupMethod(ExistsFunctionImpl.class, "eval", Object[].class)); + return function.getImplementor().implement(translator, call, RexImpTable.NullAs.NULL); + } + } + + public static Object eval(Object... args) { + org.apache.calcite.linq4j.function.Function1 lambdaFunction = + (org.apache.calcite.linq4j.function.Function1) args[1]; + List target = (List) args[0]; + try { + for (Object candidate : target) { + if ((Boolean) lambdaFunction.apply(candidate)) { + return true; + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + return false; + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/FilterFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/FilterFunctionImpl.java new file mode 100644 index 0000000000..efecb9b569 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/FilterFunctionImpl.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function.CollectionUDF; + +import java.util.ArrayList; +import java.util.List; +import org.apache.calcite.adapter.enumerable.NotNullImplementor; +import org.apache.calcite.adapter.enumerable.NullPolicy; +import org.apache.calcite.adapter.enumerable.RexImpTable; +import org.apache.calcite.adapter.enumerable.RexToLixTranslator; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Types; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.schema.impl.ScalarFunctionImpl; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.opensearch.sql.expression.function.ImplementorUDF; + +public class FilterFunctionImpl extends ImplementorUDF { + public FilterFunctionImpl() { + super(new FilterImplementor(), NullPolicy.ANY); + } + + @Override + public SqlReturnTypeInference getReturnTypeInference() { + return ReturnTypes.ARG0; + } + + public static class FilterImplementor implements NotNullImplementor { + @Override + public Expression implement( + RexToLixTranslator translator, RexCall call, List translatedOperands) { + ScalarFunctionImpl function = + (ScalarFunctionImpl) + ScalarFunctionImpl.create( + Types.lookupMethod(FilterFunctionImpl.class, "eval", Object[].class)); + return function.getImplementor().implement(translator, call, RexImpTable.NullAs.NULL); + } + } + + public static Object eval(Object... args) { + org.apache.calcite.linq4j.function.Function1 lambdaFunction = + (org.apache.calcite.linq4j.function.Function1) args[1]; + List target = (List) args[0]; + List results = new ArrayList<>(); + try { + for (Object candidate : target) { + if ((Boolean) lambdaFunction.apply(candidate)) { + results.add(candidate); + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + return results; + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ForallFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ForallFunctionImpl.java new file mode 100644 index 0000000000..54a35691a0 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ForallFunctionImpl.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function.CollectionUDF; + +import java.util.List; +import org.apache.calcite.adapter.enumerable.NotNullImplementor; +import org.apache.calcite.adapter.enumerable.NullPolicy; +import org.apache.calcite.adapter.enumerable.RexImpTable; +import org.apache.calcite.adapter.enumerable.RexToLixTranslator; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Types; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.schema.impl.ScalarFunctionImpl; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.opensearch.sql.expression.function.ImplementorUDF; + +public class ForallFunctionImpl extends ImplementorUDF { + public ForallFunctionImpl() { + super(new ForallImplementor(), NullPolicy.ALL); + } + + @Override + public SqlReturnTypeInference getReturnTypeInference() { + return ReturnTypes.BOOLEAN; + } + + public static class ForallImplementor implements NotNullImplementor { + @Override + public Expression implement( + RexToLixTranslator translator, RexCall call, List translatedOperands) { + ScalarFunctionImpl function = + (ScalarFunctionImpl) + ScalarFunctionImpl.create( + Types.lookupMethod(ForallFunctionImpl.class, "eval", Object[].class)); + return function.getImplementor().implement(translator, call, RexImpTable.NullAs.NULL); + } + } + + public static Object eval(Object... args) { + org.apache.calcite.linq4j.function.Function1 lambdaFunction = + (org.apache.calcite.linq4j.function.Function1) args[1]; + List target = (List) args[0]; + try { + for (Object candidate : target) { + if (!(Boolean) lambdaFunction.apply(candidate)) { + return false; + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + return true; + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/LambdaUtils.java b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/LambdaUtils.java new file mode 100644 index 0000000000..6f3fc1f8b0 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/LambdaUtils.java @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function.CollectionUDF; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexCallBinding; +import org.apache.calcite.rex.RexLambda; +import org.apache.calcite.rex.RexLambdaRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; + +public class LambdaUtils { + public static Object transferLambdaOutputToTargetType(Object candidate, SqlTypeName targetType) { + if (candidate instanceof BigDecimal) { + BigDecimal bd = (BigDecimal) candidate; + switch (targetType) { + case INTEGER: + return bd.intValue(); + case DOUBLE: + return bd.doubleValue(); + case FLOAT: + return bd.floatValue(); + default: + return bd; + } + } else { + return candidate; + } + } + + public static RelDataType inferReturnTypeFromLambda( + RexLambda rexLambda, Map filledTypes, RelDataTypeFactory typeFactory) { + RexCall rexCall = (RexCall) rexLambda.getExpression(); + SqlReturnTypeInference returnInfer = rexCall.getOperator().getReturnTypeInference(); + List lambdaOperands = rexCall.getOperands(); + List filledOperands = new ArrayList<>(); + for (RexNode rexNode : lambdaOperands) { + if (rexNode instanceof RexLambdaRef rexLambdaRef) { + if (rexLambdaRef.getType().getSqlTypeName() == SqlTypeName.ANY) { + filledOperands.add( + new RexLambdaRef( + rexLambdaRef.getIndex(), + rexLambdaRef.getName(), + filledTypes.get(rexLambdaRef.getName()))); + } else { + filledOperands.add(rexNode); + } + } else if (rexNode instanceof RexCall) { + filledOperands.add( + reinferReturnTypeForRexCallInsideLambda((RexCall) rexNode, filledTypes, typeFactory)); + } else { + filledOperands.add(rexNode); + } + } + return returnInfer.inferReturnType( + new RexCallBinding(typeFactory, rexCall.getOperator(), filledOperands, List.of())); + } + + public static RexCall reinferReturnTypeForRexCallInsideLambda( + RexCall rexCall, Map argTypes, RelDataTypeFactory typeFactory) { + List filledOperands = new ArrayList<>(); + List rexCallOperands = rexCall.getOperands(); + for (RexNode rexNode : rexCallOperands) { + if (rexNode instanceof RexLambdaRef rexLambdaRef) { + filledOperands.add( + new RexLambdaRef( + rexLambdaRef.getIndex(), + rexLambdaRef.getName(), + argTypes.get(rexLambdaRef.getName()))); + } else if (rexNode instanceof RexCall) { + filledOperands.add( + reinferReturnTypeForRexCallInsideLambda((RexCall) rexNode, argTypes, typeFactory)); + } else { + filledOperands.add(rexNode); + } + } + RelDataType returnType = + rexCall + .getOperator() + .inferReturnType( + new RexCallBinding(typeFactory, rexCall.getOperator(), filledOperands, List.of())); + return rexCall.clone(returnType, filledOperands); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ReduceFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ReduceFunctionImpl.java new file mode 100644 index 0000000000..4b655c3b09 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/ReduceFunctionImpl.java @@ -0,0 +1,123 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function.CollectionUDF; + +import static org.opensearch.sql.expression.function.CollectionUDF.LambdaUtils.inferReturnTypeFromLambda; +import static org.opensearch.sql.expression.function.CollectionUDF.LambdaUtils.transferLambdaOutputToTargetType; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.calcite.adapter.enumerable.NotNullImplementor; +import org.apache.calcite.adapter.enumerable.NullPolicy; +import org.apache.calcite.adapter.enumerable.RexToLixTranslator; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.linq4j.tree.Types; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexCallBinding; +import org.apache.calcite.rex.RexLambda; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.type.ArraySqlType; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.opensearch.sql.expression.function.ImplementorUDF; + +public class ReduceFunctionImpl extends ImplementorUDF { + public ReduceFunctionImpl() { + super(new ReduceImplementor(), NullPolicy.ANY); + } + + @Override + public SqlReturnTypeInference getReturnTypeInference() { + return sqlOperatorBinding -> { + RelDataTypeFactory typeFactory = sqlOperatorBinding.getTypeFactory(); + RexCallBinding rexCallBinding = (RexCallBinding) sqlOperatorBinding; + List rexNodes = rexCallBinding.operands(); + ArraySqlType listType = (ArraySqlType) rexNodes.get(0).getType(); + RelDataType elementType = listType.getComponentType(); + RelDataType baseType = rexNodes.get(1).getType(); + Map map = new HashMap<>(); + RexLambda mergeLambda = (RexLambda) rexNodes.get(2); + map.put(mergeLambda.getParameters().get(0).getName(), baseType); + map.put(mergeLambda.getParameters().get(1).getName(), elementType); + RelDataType mergedReturnType = + inferReturnTypeFromLambda((RexLambda) rexNodes.get(2), map, typeFactory); + if (mergedReturnType != baseType) { // For different acc, we need to recalculate + map.put(mergeLambda.getParameters().get(0).getName(), mergedReturnType); + mergedReturnType = inferReturnTypeFromLambda((RexLambda) rexNodes.get(2), map, typeFactory); + } + RelDataType finalReturnType; + if (rexNodes.size() > 3) { + finalReturnType = inferReturnTypeFromLambda((RexLambda) rexNodes.get(3), map, typeFactory); + } else { + finalReturnType = mergedReturnType; + } + return finalReturnType; + + /* + RelDataTypeFactory typeFactory = sqlOperatorBinding.getTypeFactory(); + RexCallBinding rexCallBinding = (RexCallBinding) sqlOperatorBinding; + List operands = rexCallBinding.operands(); + RelDataType mergedReturnType = + ((RexLambda) operands.get(2)).getExpression().getType(); + if (operands.size() > 3) { + RelDataType reduceReturnType = + ((RexLambda) operands.get(3)).getExpression().getType(); + return typeFactory.leastRestrictive(List.of(mergedReturnType, reduceReturnType)); + } + return mergedReturnType; + + */ + }; + } + + public static class ReduceImplementor implements NotNullImplementor { + @Override + public Expression implement( + RexToLixTranslator translator, RexCall call, List translatedOperands) { + List withReturnTypeList = new ArrayList<>(translatedOperands); + withReturnTypeList.add(Expressions.constant(call.getType().getSqlTypeName())); + return Expressions.call( + Types.lookupMethod(ReduceFunctionImpl.class, "eval", Object[].class), withReturnTypeList); + } + } + + public static Object eval(Object... args) { + List list = (List) args[0]; + SqlTypeName returnTypes = (SqlTypeName) args[args.length - 1]; + Object base = args[1]; + if (args[2] instanceof org.apache.calcite.linq4j.function.Function2) { + org.apache.calcite.linq4j.function.Function2 lambdaFunction = + (org.apache.calcite.linq4j.function.Function2) args[2]; + + try { + for (int i = 0; i < list.size(); i++) { + base = + transferLambdaOutputToTargetType( + lambdaFunction.apply(base, list.get(i)), returnTypes); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + if (args.length == 5) { + if (args[3] instanceof org.apache.calcite.linq4j.function.Function1) { + return transferLambdaOutputToTargetType( + ((org.apache.calcite.linq4j.function.Function1) args[3]).apply(base), returnTypes); + } else { + throw new IllegalArgumentException("wrong lambda function input"); + } + } else { + return base; + } + } else { + throw new IllegalArgumentException("wrong lambda function input"); + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/TransformFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/TransformFunctionImpl.java new file mode 100644 index 0000000000..79c31dfc99 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/CollectionUDF/TransformFunctionImpl.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function.CollectionUDF; + +import static org.apache.calcite.sql.type.SqlTypeUtil.createArrayType; +import static org.opensearch.sql.expression.function.CollectionUDF.LambdaUtils.transferLambdaOutputToTargetType; + +import java.util.ArrayList; +import java.util.List; +import org.apache.calcite.adapter.enumerable.NotNullImplementor; +import org.apache.calcite.adapter.enumerable.NullPolicy; +import org.apache.calcite.adapter.enumerable.RexToLixTranslator; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.linq4j.tree.Types; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexCallBinding; +import org.apache.calcite.rex.RexLambda; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.type.ArraySqlType; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeName; +import org.opensearch.sql.expression.function.ImplementorUDF; + +public class TransformFunctionImpl extends ImplementorUDF { + public TransformFunctionImpl() { + super(new TransformImplementor(), NullPolicy.ANY); + } + + @Override + public SqlReturnTypeInference getReturnTypeInference() { + return sqlOperatorBinding -> { + RelDataTypeFactory typeFactory = sqlOperatorBinding.getTypeFactory(); + RexCallBinding rexCallBinding = (RexCallBinding) sqlOperatorBinding; + List operands = rexCallBinding.operands(); + RelDataType lambdaReturnType = ((RexLambda) operands.get(1)).getExpression().getType(); + return createArrayType( + typeFactory, typeFactory.createTypeWithNullability(lambdaReturnType, true), true); + }; + } + + public static class TransformImplementor implements NotNullImplementor { + @Override + public Expression implement( + RexToLixTranslator translator, RexCall call, List translatedOperands) { + ArraySqlType arrayType = (ArraySqlType) call.getType(); + List withReturnTypeList = new ArrayList<>(translatedOperands); + withReturnTypeList.add(Expressions.constant(arrayType.getComponentType().getSqlTypeName())); + return Expressions.call( + Types.lookupMethod(TransformFunctionImpl.class, "eval", Object[].class), + withReturnTypeList); + } + } + + public static Object eval(Object... args) { + List target = (List) args[0]; + List results = new ArrayList<>(); + SqlTypeName returnType = (SqlTypeName) args[args.length - 1]; + if (args[1] instanceof org.apache.calcite.linq4j.function.Function1) { + org.apache.calcite.linq4j.function.Function1 lambdaFunction = + (org.apache.calcite.linq4j.function.Function1) args[1]; + + try { + for (Object candidate : target) { + results.add( + transferLambdaOutputToTargetType(lambdaFunction.apply(candidate), returnType)); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + return results; + } else if (args[1] instanceof org.apache.calcite.linq4j.function.Function2) { + org.apache.calcite.linq4j.function.Function2 lambdaFunction = + (org.apache.calcite.linq4j.function.Function2) args[1]; + try { + for (int i = 0; i < target.size(); i++) { + results.add( + transferLambdaOutputToTargetType(lambdaFunction.apply(target.get(i), i), returnType)); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + return results; + } else { + throw new IllegalArgumentException("wrong lambda function input"); + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java index e1ff69660b..af621e965d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java @@ -16,12 +16,25 @@ import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable; import org.apache.calcite.util.BuiltInMethod; +import org.opensearch.sql.expression.function.CollectionUDF.ArrayFunctionImpl; +import org.opensearch.sql.expression.function.CollectionUDF.ExistsFunctionImpl; +import org.opensearch.sql.expression.function.CollectionUDF.FilterFunctionImpl; +import org.opensearch.sql.expression.function.CollectionUDF.ForallFunctionImpl; +import org.opensearch.sql.expression.function.CollectionUDF.ReduceFunctionImpl; +import org.opensearch.sql.expression.function.CollectionUDF.TransformFunctionImpl; /** Defines functions and operators that are implemented only by PPL */ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { public static final SqlOperator SPAN = new SpanFunctionImpl().toUDF("SPAN"); + public static final SqlOperator FORALL = new ForallFunctionImpl().toUDF("forall"); + public static final SqlOperator EXISTS = new ExistsFunctionImpl().toUDF("exists"); + public static final SqlOperator ARRAY = new ArrayFunctionImpl().toUDF("array"); + public static final SqlOperator FILTER = new FilterFunctionImpl().toUDF("filter"); + public static final SqlOperator TRANSFORM = new TransformFunctionImpl().toUDF("transform"); + public static final SqlOperator REDUCE = new ReduceFunctionImpl().toUDF("reduce"); + /** * Invoking an implementor registered in {@link RexImpTable}, need to use reflection since they're * all private Use method directly in {@link BuiltInMethod} if possible, most operators' diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index c5bd3db466..357bb1cf4c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -8,65 +8,7 @@ import static java.lang.Math.E; import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY; import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.getLegacyTypeName; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.ABS; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.ACOS; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADD; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.AND; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.ASCII; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.ASIN; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.ATAN; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.ATAN2; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.CBRT; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.CEILING; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.CONCAT; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.CONCAT_WS; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.COS; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.COT; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.DEGREES; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.EXP; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.FLOOR; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.GREATER; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.GTE; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LEFT; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LENGTH; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LESS; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LIKE; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LN; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LOG; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LOG10; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LOG2; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LOWER; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LTE; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.LTRIM; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.MULTIPLY; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.NOT; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.NOTEQUAL; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.OR; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.PI; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.POSITION; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.POW; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.POWER; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.RADIANS; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.RAND; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.REGEXP; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.REVERSE; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.RIGHT; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.ROUND; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.RTRIM; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.SIGN; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.SIN; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.SPAN; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.STRCMP; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBSTR; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBSTRING; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTRACT; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.TRIM; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.TYPEOF; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.UPPER; -import static org.opensearch.sql.expression.function.BuiltinFunctionName.XOR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.*; import com.google.common.collect.ImmutableMap; import java.math.BigDecimal; @@ -272,6 +214,13 @@ void populate() { // Register PPL UDF operator registerOperator(SPAN, PPLBuiltinOperators.SPAN); + registerOperator(ARRAY, PPLBuiltinOperators.ARRAY); + registerOperator(FORALL, PPLBuiltinOperators.FORALL); + registerOperator(EXISTS, PPLBuiltinOperators.EXISTS); + registerOperator(FILTER, PPLBuiltinOperators.FILTER); + registerOperator(TRANSFORM, PPLBuiltinOperators.TRANSFORM); + registerOperator(REDUCE, PPLBuiltinOperators.REDUCE); + // Register implementation. // Note, make the implementation an individual class if too complex. register( diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalciteArrayFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalciteArrayFunctionIT.java new file mode 100644 index 0000000000..eeba914114 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalciteArrayFunctionIT.java @@ -0,0 +1,170 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.standalone; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; +import static org.opensearch.sql.util.MatcherUtils.*; + +import java.io.IOException; +import java.util.List; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; + +public class CalciteArrayFunctionIT extends CalcitePPLIntegTestCase { + @Override + public void init() throws IOException { + super.init(); + loadIndex(Index.BANK); + } + + @Test + public void testForAll() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | eval array = array(1, -1, 2), result = forall(array, x -> x > 0) |" + + " fields result | head 1", + TEST_INDEX_BANK)); + + verifySchema(actual, schema("result", "boolean")); + + verifyDataRows(actual, rows(false)); + } + + @Test + public void testExists() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | eval array = array(1, -1, 2), result = exists(array, x -> x > 0) |" + + " fields result | head 1", + TEST_INDEX_BANK)); + + verifySchema(actual, schema("result", "boolean")); + + verifyDataRows(actual, rows(true)); + } + + @Test + public void testFilter() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | eval array = array(1, -1, 2), result = filter(array, x -> x > 0) |" + + " fields result | head 1", + TEST_INDEX_BANK)); + + verifySchema(actual, schema("result", "array")); + + verifyDataRows(actual, rows(List.of(1, 2))); + } + + @Test + public void testTransform() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | eval array = array(1, 2, 3), result = transform(array, x -> x + 1) |" + + " fields result | head 1", + TEST_INDEX_BANK)); + + verifySchema(actual, schema("result", "array")); + + verifyDataRows(actual, rows(List.of(2, 3, 4))); + } + + @Test + public void testTransformForTwoInput() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | eval array = array(1, 2, 3), result = transform(array, (x, i) -> x +" + + " i) | fields result | head 1", + TEST_INDEX_BANK)); + + verifySchema(actual, schema("result", "array")); + + verifyDataRows(actual, rows(List.of(1, 3, 5))); + } + + @Test + public void testTransformForWithDouble() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | eval array = array(1, 2, 3), result = transform(array, (x, i) -> x +" + + " i * 10.1) | fields result | head 1", + TEST_INDEX_BANK)); + + verifySchema(actual, schema("result", "array")); + + verifyDataRows(actual, rows(List.of(1, 12.1, 23.2))); + } + + @Test + public void testTransformForWithUDF() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | eval array = array(TIMESTAMP('2000-01-02 00:00:00')," + + " TIMESTAMP('2000-01-03 00:00:00'), TIMESTAMP('2000-01-04 00:00:00')), result" + + " = transform(array, (x, i) -> DATEDIFF(x, TIMESTAMP('2000-01-01 23:59:59'))" + + " + i * 10.1) | fields result | head 1", + TEST_INDEX_BANK)); + + verifySchema(actual, schema("result", "array")); + + verifyDataRows(actual, rows(List.of(1, 12.1, 23.2))); + } + + @Test + public void testReduce() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | eval array = array(1, 2, 3), result = reduce(array, 0, (acc, x) -> acc" + + " + x), result2 = reduce(array, 10, (acc, x) -> acc + x), result3 =" + + " reduce(array, 0, (acc, x) -> acc + x, acc -> acc * 10.0) | fields" + + " result,result2, result3 | head 1", + TEST_INDEX_BANK)); + + verifySchema( + actual, + schema("result", "integer"), + schema("result2", "integer"), + schema("result3", "double")); + + verifyDataRows(actual, rows(6, 16, 60)); + } + + @Test + public void testReduce2() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | eval array = array(1.0, 2.0, 3.0), result3 = reduce(array, 0, (acc, x)" + + " -> acc * 10.0 + x, acc -> acc * 10.0) | fields result3 | head 1", + TEST_INDEX_BANK)); + + verifySchema(actual, schema("result3", "double")); + + verifyDataRows(actual, rows(1230)); + } + + @Test + public void testReduceWithUDF() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | eval array = array('a', 'ab', 'abc'), result3 = reduce(array, 0, (acc," + + " x) -> acc + length(x), acc -> acc * 10.0) | fields result3 | head 1", + TEST_INDEX_BANK)); + + verifySchema(actual, schema("result3", "double")); + + verifyDataRows(actual, rows(60)); + } +} diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 4dd27b2092..7d55fece47 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -191,6 +191,7 @@ RT_SQR_PRTHS: ']'; SINGLE_QUOTE: '\''; DOUBLE_QUOTE: '"'; BACKTICK: '`'; +ARROW: '->'; // Operators. Bit @@ -352,6 +353,13 @@ ISNOTNULL: 'ISNOTNULL'; CIDRMATCH: 'CIDRMATCH'; BETWEEN: 'BETWEEN'; +// COLLECTION FUNCTIONS +ARRAY: 'ARRAY'; +FORALL: 'FORALL'; +FILTER: 'FILTER'; +TRANSFORM: 'TRANSFORM'; +REDUCE: 'REDUCE'; + // JSON FUNCTIONS JSON_VALID: 'JSON_VALID'; JSON: 'JSON'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 93098478ea..efce4d7e75 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -431,6 +431,7 @@ valueExpression | timestampFunction # timestampFunctionCall | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr + | lambda # lambdaExpr ; primaryExpression @@ -548,6 +549,8 @@ evalFunctionName | positionFunctionName | jsonFunctionName | geoipFunctionName + | collectionFunctionName + | lambdaFunctionName ; functionArgs @@ -555,7 +558,18 @@ functionArgs ; functionArg - : (ident EQUAL)? expression + : (ident EQUAL)? functionArgExpression + ; + + +functionArgExpression + : lambda + | expression + ; + +lambda + : ident ARROW expression + | LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW expression ; relevanceArg @@ -653,6 +667,18 @@ geoipFunctionName : GEOIP ; +collectionFunctionName + : ARRAY + ; + +lambdaFunctionName + : FORALL + | EXISTS + | FILTER + | TRANSFORM + | REDUCE + ; + trigonometricFunctionName : ACOS | ASIN @@ -1000,6 +1026,7 @@ keywordsCanBeId | singleFieldRelevanceFunctionName | multiFieldRelevanceFunctionName | commandName + | collectionFunctionName | comparisonOperator | patternMethod | explainMode @@ -1007,6 +1034,7 @@ keywordsCanBeId | CASE | ELSE | IN + | ARROW | BETWEEN | EXISTS | SOURCE diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index bc8e89387d..affcb1eafd 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -506,7 +506,8 @@ public UnresolvedPlan visitTableFunction(TableFunctionContext ctx) { arg -> { String argName = (arg.ident() != null) ? arg.ident().getText() : null; builder.add( - new UnresolvedArgument(argName, this.internalVisitExpression(arg.expression()))); + new UnresolvedArgument( + argName, this.internalVisitExpression(arg.functionArgExpression()))); }); return new TableFunction(this.internalVisitExpression(ctx.qualifiedName()), builder.build()); } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 0755d035de..535ea85921 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -44,6 +44,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; @@ -132,6 +133,17 @@ public UnresolvedExpression visitLogicalXor(LogicalXorContext ctx) { return new Xor(visit(ctx.left), visit(ctx.right)); } + /** lambda expression */ + @Override + public UnresolvedExpression visitLambda(OpenSearchPPLParser.LambdaContext ctx) { + List arguments = + ctx.ident().stream() + .map(x -> this.visitIdentifiers(Collections.singletonList(x))) + .collect(Collectors.toList()); + UnresolvedExpression function = visitExpression(ctx.expression()); + return new LambdaFunction(function, arguments); + } + /** Comparison expression. */ @Override public UnresolvedExpression visitCompareExpr(CompareExprContext ctx) {