Skip to content

Add lambda function and array related functions #3584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.sql.ast.expression.HighlightFunction;
import org.opensearch.sql.ast.expression.In;
import org.opensearch.sql.ast.expression.Interval;
import org.opensearch.sql.ast.expression.LambdaFunction;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Map;
Expand Down Expand Up @@ -231,6 +232,10 @@ public T visitSort(Sort node, C context) {
return visitChildren(node, context);
}

public T visitLambdaFunction(LambdaFunction node, C context) {
return visitChildren(node, context);
}

public T visitDedupe(Dedupe node, C context) {
return visitChildren(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/**
* Expression node of lambda function. Params include function name (@funcName) and function
* arguments (@funcArgs)
*/
@Getter
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class LambdaFunction extends UnresolvedExpression {
private final UnresolvedExpression function;
private final List<QualifiedName> funcArgs;

@Override
public List<UnresolvedExpression> getChild() {
List<UnresolvedExpression> children = new ArrayList<>();
children.add(function);
children.addAll(funcArgs);
return children;
}

@Override
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visitLambdaFunction(this, context);
}

@Override
public String toString() {
return String.format(
"(%s) -> %s",
funcArgs.stream().map(Object::toString).collect(Collectors.joining(", ")),
function.toString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;

import java.sql.Connection;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Stack;
import java.util.function.BiFunction;
import lombok.Getter;
import lombok.Setter;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexLambdaRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.RelBuilder;
Expand Down Expand Up @@ -44,13 +47,16 @@ public class CalcitePlanContext {

private final Stack<RexCorrelVariable> correlVar = new Stack<>();

@Getter public Map<String, RexLambdaRef> temparolInputMap;

private CalcitePlanContext(FrameworkConfig config, QueryType queryType) {
this.config = config;
this.queryType = queryType;
this.connection = CalciteToolsHelper.connect(config, TYPE_FACTORY);
this.relBuilder = CalciteToolsHelper.create(config, TYPE_FACTORY, connection);
this.rexBuilder = new ExtendedRexBuilder(relBuilder.getRexBuilder());
this.functionProperties = new FunctionProperties(QueryType.PPL);
this.temparolInputMap = new HashMap<>();
}

public RexNode resolveJoinCondition(
Expand Down Expand Up @@ -82,7 +88,15 @@ public Optional<RexCorrelVariable> peekCorrelVar() {
}
}

public CalcitePlanContext clone() {
return new CalcitePlanContext(config, queryType);
}

public static CalcitePlanContext create(FrameworkConfig config, QueryType queryType) {
return new CalcitePlanContext(config, queryType);
}

public void putTemparolInputmapAll(Map<String, RexLambdaRef> candidateMap) {
this.temparolInputMap.putAll(candidateMap);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,29 @@
import static org.opensearch.sql.ast.expression.SpanUnit.NONE;
import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN;
import static org.opensearch.sql.calcite.utils.BuiltinFunctionUtils.VARCHAR_FORCE_NULLABLE;
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.TransferUserDefinedFunction;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import lombok.RequiredArgsConstructor;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLambdaRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.ArraySqlType;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.DateString;
import org.apache.calcite.util.TimeString;
Expand All @@ -39,6 +47,7 @@
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.In;
import org.opensearch.sql.ast.expression.Interval;
import org.opensearch.sql.ast.expression.LambdaFunction;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
Expand Down Expand Up @@ -284,6 +293,9 @@ public RexNode visitQualifiedName(QualifiedName node, CalcitePlanContext context
// TODO: Need to support nested fields https://github.com/opensearch-project/sql/issues/3459
// 2. resolve QualifiedName in non-join condition
String qualifiedName = node.toString();
if (context.getTemparolInputMap().containsKey(qualifiedName)) {
return context.getTemparolInputMap().get(qualifiedName);
}
List<String> currentFields = context.relBuilder.peek().getRowType().getFieldNames();
if (currentFields.contains(qualifiedName)) {
// 2.1 resolve QualifiedName from stack top
Expand Down Expand Up @@ -337,16 +349,112 @@ private boolean isTimeBased(SpanUnit unit) {
return !(unit == NONE || unit == UNKNOWN);
}

@Override
public RexNode visitLambdaFunction(LambdaFunction node, CalcitePlanContext context) {
try {
List<QualifiedName> names = node.getFuncArgs();
List<RexLambdaRef> args =
IntStream.range(0, names.size())
.mapToObj(
i ->
context.temparolInputMap.getOrDefault(
names.get(i).toString(),
new RexLambdaRef(
i,
names.get(i).toString(),
TYPE_FACTORY.createSqlType(SqlTypeName.ANY))))
.collect(Collectors.toList());
RexNode body = node.getFunction().accept(this, context);
RexNode lambdaNode = context.rexBuilder.makeLambdaCall(body, args);
return lambdaNode;
} catch (Exception e) {
throw new RuntimeException("Cannot create lambda function", e);
}
}

@Override
public RexNode visitLet(Let node, CalcitePlanContext context) {
RexNode expr = analyze(node.getExpression(), context);
return context.relBuilder.alias(expr, node.getVar().getField().toString());
}

/**
* The function will clone a context for lambda function. For lambda like (x, y, z) -> ..., we
* will map type for each lambda argument by the order of previous argument. Also, the function
* will add these variables to the context so they can pass visitQualifiedName
*/
private CalcitePlanContext prepareLambdaContext(
CalcitePlanContext context,
LambdaFunction node,
List<RexNode> previousArgument,
String functionName) {
try {
CalcitePlanContext lambdaContext = context.clone();
List<RelDataType> candidateType = new ArrayList<>();
candidateType.add(
((ArraySqlType) previousArgument.get(0).getType())
.getComponentType()); // The first argument should be array type
candidateType.addAll(previousArgument.stream().skip(1).map(RexNode::getType).toList());
candidateType = modifyLambdaTypeByFunction(functionName, candidateType);
List<QualifiedName> argNames = node.getFuncArgs();
Map<String, RexLambdaRef> lambdaTypes = new HashMap<>();
int candidateIndex;
candidateIndex = 0;
for (int i = 0; i < argNames.size(); i++) {
RelDataType type;
if (candidateIndex < candidateType.size()) {
type = candidateType.get(candidateIndex);
candidateIndex++;
} else {
type =
TYPE_FACTORY.createSqlType(
SqlTypeName.INTEGER); // For transform function, the i is missing in input.
}
lambdaTypes.put(
argNames.get(i).toString(), new RexLambdaRef(i, argNames.get(i).toString(), type));
}
lambdaContext.putTemparolInputmapAll(lambdaTypes);
return lambdaContext;
} catch (Exception e) {
throw new RuntimeException("Fail to prepare lambda context", e);
}
}

/**
* @param functionName function name
* @param originalType the argument type by order
* @return a modified types. Different functions need to implement its own order. Currently, only
* reduce has special logic.
*/
private List<RelDataType> modifyLambdaTypeByFunction(
String functionName, List<RelDataType> originalType) {
switch (functionName.toUpperCase(Locale.ROOT)) {
case "REDUCE": // For reduce case, the first type is acc should be any
return Stream.concat(
Stream.of(TYPE_FACTORY.createSqlType(SqlTypeName.ANY)),
originalType.subList(0, originalType.size() - 1).stream())
.collect(Collectors.toList());
default:
return originalType;
}
}

@Override
public RexNode visitFunction(Function node, CalcitePlanContext context) {
List<RexNode> arguments =
node.getFuncArgs().stream().map(arg -> analyze(arg, context)).collect(Collectors.toList());
List<UnresolvedExpression> args = node.getFuncArgs();
List<RexNode> arguments = new ArrayList<>();
for (UnresolvedExpression arg : args) {
if (arg instanceof LambdaFunction) {
CalcitePlanContext lambdaContext =
prepareLambdaContext(context, (LambdaFunction) arg, arguments, node.getFuncName());
arguments.add(analyze(arg, lambdaContext));
} else {
arguments.add(analyze(arg, context));
}
}
// List<RexNode> arguments =
// node.getFuncArgs().stream().map(arg -> analyze(arg,
// context)).collect(Collectors.toList());
RexNode resolvedNode =
PPLFuncImpTable.INSTANCE.resolveSafe(
context.rexBuilder, node.getFuncName(), arguments.toArray(new RexNode[0]));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@ public class UserDefinedFunctionUtils {
public static RelDataType nullableDateUDT = TYPE_FACTORY.createUDT(EXPR_DATE, true);
public static RelDataType nullableTimestampUDT =
TYPE_FACTORY.createUDT(ExprUDT.EXPR_TIMESTAMP, true);

public static SqlReturnTypeInference timestampInference =
ReturnTypes.explicit(nullableTimestampUDT);

public static SqlReturnTypeInference timeInference = ReturnTypes.explicit(nullableTimeUDT);

public static SqlReturnTypeInference dateInference = ReturnTypes.explicit(nullableDateUDT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ public enum BuiltinFunctionName {
TAN(FunctionName.of("tan")),
SPAN(FunctionName.of("span")),

/** Collection functions */
ARRAY(FunctionName.of("array")),
FORALL(FunctionName.of("forall")),
EXISTS(FunctionName.of("exists")),
FILTER(FunctionName.of("filter")),
TRANSFORM(FunctionName.of("transform")),
REDUCE(FunctionName.of("reduce")),

/** Date and Time Functions. */
ADDDATE(FunctionName.of("adddate")),
ADDTIME(FunctionName.of("addtime")),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function.CollectionUDF;

import static org.apache.calcite.sql.type.SqlTypeUtil.createArrayType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
import org.apache.calcite.adapter.enumerable.NullPolicy;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.opensearch.sql.expression.function.ImplementorUDF;

public class ArrayFunctionImpl extends ImplementorUDF {
public ArrayFunctionImpl() {
super(new ArrayImplementor(), NullPolicy.ANY);
}

@Override
public SqlReturnTypeInference getReturnTypeInference() {
return sqlOperatorBinding -> {
RelDataTypeFactory typeFactory = sqlOperatorBinding.getTypeFactory();
List<RelDataType> argTypes = sqlOperatorBinding.collectOperandTypes();
RelDataType commonType = typeFactory.leastRestrictive(argTypes);
if (commonType == null) {
throw new IllegalArgumentException(
"All arguments in json array cannot be converted into one common types");
}
return createArrayType(
typeFactory, typeFactory.createTypeWithNullability(commonType, true), true);
};
}

public static class ArrayImplementor implements NotNullImplementor {
@Override
public Expression implement(
RexToLixTranslator translator, RexCall call, List<Expression> translatedOperands) {
RelDataType realType = call.getType().getComponentType();
List<Expression> newArgs = new ArrayList<>(translatedOperands);
assert realType != null;
newArgs.add(Expressions.constant(realType.getSqlTypeName()));
return Expressions.call(
Types.lookupMethod(ArrayFunctionImpl.class, "eval", Object[].class), newArgs);
}
}

public static Object eval(Object... args) {
SqlTypeName targetType = (SqlTypeName) args[args.length - 1];
switch (targetType) {
case DOUBLE:
List<Object> unboxed =
IntStream.range(0, args.length - 1)
.mapToObj(i -> ((Number) args[i]).doubleValue())
.collect(Collectors.toList());

return unboxed;
case FLOAT:
List<Object> unboxedFloat =
IntStream.range(0, args.length - 1)
.mapToObj(i -> ((Number) args[i]).floatValue())
.collect(Collectors.toList());
return unboxedFloat;
default:
return Arrays.asList(args).subList(0, args.length - 1);
}
}
}
Loading
Loading