diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java b/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java index 7c0872b67fc..337e89f3f15 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java @@ -8,6 +8,7 @@ import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY; import java.sql.Connection; +import java.sql.SQLException; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -19,10 +20,13 @@ import org.apache.calcite.rex.RexCorrelVariable; import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.server.CalciteServerStatement; +import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.tools.FrameworkConfig; import org.apache.calcite.tools.RelBuilder; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.calcite.utils.CalciteToolsHelper; +import org.opensearch.sql.calcite.validate.TypeChecker; import org.opensearch.sql.executor.QueryType; import org.opensearch.sql.expression.function.FunctionProperties; @@ -35,6 +39,7 @@ public class CalcitePlanContext { public final FunctionProperties functionProperties; public final QueryType queryType; public final Integer querySizeLimit; + @Getter public final SqlValidator validator; @Getter @Setter private boolean isResolvingJoinCondition = false; @Getter @Setter private boolean isResolvingSubquery = false; @@ -61,6 +66,13 @@ private CalcitePlanContext(FrameworkConfig config, Integer querySizeLimit, Query this.rexBuilder = new ExtendedRexBuilder(relBuilder.getRexBuilder()); this.functionProperties = new FunctionProperties(QueryType.PPL); this.rexLambdaRefMap = new HashMap<>(); + final CalciteServerStatement statement; + try { + statement = connection.createStatement().unwrap(CalciteServerStatement.class); + } catch 
(SQLException e) { + throw new RuntimeException(e); + } + this.validator = TypeChecker.getValidator(statement, config); } public RexNode resolveJoinCondition( diff --git a/core/src/main/java/org/opensearch/sql/calcite/PplRelToSqlConverter.java b/core/src/main/java/org/opensearch/sql/calcite/PplRelToSqlConverter.java new file mode 100644 index 00000000000..3899c8eed1d --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/calcite/PplRelToSqlConverter.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite; + +import org.apache.calcite.rel.rel2sql.RelToSqlConverter; +import org.apache.calcite.sql.SqlDialect; + +/** + * An extension of {@link RelToSqlConverter} to convert a relation algebra tree, translated from a + * PPL query, into a SQL statement. + * + *

Currently, we haven't implemented any specific change to it, just leaving it for future + * extension. + */ +public class PplRelToSqlConverter extends RelToSqlConverter { + /** + * Creates a RelToSqlConverter. + * + * @param dialect the SQL dialect to use + */ + public PplRelToSqlConverter(SqlDialect dialect) { + super(dialect); + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java b/core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java index 14c8d8f369e..a814d8edc41 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/CalciteToolsHelper.java @@ -80,6 +80,8 @@ import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorUtil; import org.apache.calcite.sql2rel.SqlRexConvertletTable; import org.apache.calcite.tools.FrameworkConfig; import org.apache.calcite.tools.Frameworks; @@ -90,6 +92,7 @@ import org.opensearch.sql.calcite.CalcitePlanContext; import org.opensearch.sql.calcite.plan.Scannable; import org.opensearch.sql.calcite.udf.udaf.NullableSqlAvgAggFunction; +import org.opensearch.sql.calcite.validate.PplOpTable; /** * Calcite Tools Helper. This class is used to create customized: 1. Connection 2. 
JavaTypeFactory @@ -240,7 +243,7 @@ public R perform( * return {@link OpenSearchCalcitePreparingStmt} */ @Override - protected CalcitePrepareImpl.CalcitePreparingStmt getPreparingStmt( + public CalcitePrepareImpl.CalcitePreparingStmt getPreparingStmt( CalcitePrepare.Context context, Type elementType, CalciteCatalogReader catalogReader, @@ -332,6 +335,25 @@ public Type getElementType() { } return super.implement(root); } + + /** + * Imitated {@link org.apache.calcite.prepare.CalcitePrepareImpl}#createSqlValidator to create a + * SqlValidator + */ + protected SqlValidator createSqlValidator(CalciteCatalogReader catalogReader) { + return SqlValidatorUtil.newValidator( + // this is different from the original implementation + PplOpTable.getInstance(), + catalogReader, + context.getTypeFactory(), + // this may be customized in the future + SqlValidator.Config.DEFAULT); + } + + @Override + public SqlValidator getSqlValidator() { + return super.getSqlValidator(); + } } public static class OpenSearchRelRunners { diff --git a/core/src/main/java/org/opensearch/sql/calcite/validate/PplOpTable.java b/core/src/main/java/org/opensearch/sql/calcite/validate/PplOpTable.java new file mode 100644 index 00000000000..56b8b88c12e --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/calcite/validate/PplOpTable.java @@ -0,0 +1,129 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.validate; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.SqlSyntax; 
+import org.apache.calcite.sql.validate.SqlNameMatcher; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +/** + * PPLOpTable is a custom implementation of {@link SqlOperatorTable} that provides a way to register + * and look up PPL operators. + */ +public class PplOpTable implements SqlOperatorTable { + // Implementation notes: + // - Did not extend ListSqlOperatorTable because it does not support registering multiple + // SqlOperator to one name. + // - Did not extend ReflectiveSqlOperatorTable because it relies on reflectively looking for + // member fields of + // SqlOperator type, which is not suitable for our use case. + // - Did not add SqlOperatorTable to PPLFuncImpTable to reduce chaos with existing implementation + + protected Map> operators; + + private static final PplOpTable INSTANCE = new PplOpTable(); + + public static PplOpTable getInstance() { + return INSTANCE; + } + + private PplOpTable() { + this.operators = new HashMap<>(); + } + + @Override + public void lookupOperatorOverloads( + SqlIdentifier opName, + @Nullable SqlFunctionCategory category, + SqlSyntax syntax, + List operatorList, + SqlNameMatcher nameMatcher) { + if (!opName.isSimple()) { + return; + } + final String simpleName = opName.getSimple(); + lookUpOperators( + simpleName, + op -> { + if (op.getSyntax() != syntax && op.getSyntax().family != syntax.family) { + // Allow retrieval on exact syntax or family; for example, + // CURRENT_DATETIME has FUNCTION_ID syntax but can also be called with + // both FUNCTION_ID and FUNCTION syntax (e.g. SELECT CURRENT_DATETIME, + // CURRENT_DATETIME('UTC')). 
+ return; + } + if (category != null + && category != category(op) + && !category.isUserDefinedNotSpecificFunction()) { + return; + } + operatorList.add(op); + }); + } + + protected void lookUpOperators(String name, Consumer consumer) { + final BuiltinFunctionName funcNameOpt = sqlFunctionNameToPplFunctionName(name); + if (funcNameOpt == null) { + return; // No function with this name + } + if (!operators.containsKey(funcNameOpt)) { + return; // The function is not registered + } + operators.get(funcNameOpt).forEach(consumer); + } + + /** + * At this stage, the function name of a Calcite's builtin operator is acquired via + * `sqlFunction.getSqlIdentifier()` + * + *

This will return the name in Calcite, instead of that registered in PPL. We use this method + * to convert the Calcite function name to the PPL function name. + */ + private BuiltinFunctionName sqlFunctionNameToPplFunctionName(String name) { + return switch (name.toUpperCase(Locale.ROOT)) { + case "CONVERT" -> BuiltinFunctionName.CONV; + case "ILIKE" -> BuiltinFunctionName.LIKE; + case "CHAR_LENGTH" -> BuiltinFunctionName.LENGTH; + case "NOT_EQUALS", "<>" -> BuiltinFunctionName.XOR; + default -> BuiltinFunctionName.of(name).orElse(null); + }; + } + + protected static SqlFunctionCategory category(SqlOperator operator) { + if (operator instanceof SqlFunction) { + return ((SqlFunction) operator).getFunctionType(); + } else { + return SqlFunctionCategory.SYSTEM; + } + } + + @Override + public List getOperatorList() { + return operators.values().stream() + .flatMap(iterable -> StreamSupport.stream(iterable.spliterator(), false)) + .collect(Collectors.toList()); + } + + public void add(BuiltinFunctionName name, SqlOperator operator) { + ArrayList list = operators.getOrDefault(name, new ArrayList<>()); + list.add(operator); + operators.put(name, list); + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/validate/PplTypeCoercion.java b/core/src/main/java/org/opensearch/sql/calcite/validate/PplTypeCoercion.java new file mode 100644 index 00000000000..4407621d7a1 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/calcite/validate/PplTypeCoercion.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.validate; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.IntStream; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.type.SqlTypeFamily; +import 
org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.implicit.TypeCoercionImpl; + +public class PplTypeCoercion extends TypeCoercionImpl { + // A blacklist of coercions that are not allowed in PPL. + // key cannot be cast from values + private static final Map> BLACKLISTED_COERCIONS; + + static { + // Initialize the blacklist for coercions that are not allowed in PPL. + BLACKLISTED_COERCIONS = + Map.of( + SqlTypeFamily.CHARACTER, + Set.of(SqlTypeFamily.NUMERIC), + SqlTypeFamily.STRING, + Set.of(SqlTypeFamily.NUMERIC), + SqlTypeFamily.NUMERIC, + Set.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.STRING)); + } + + public PplTypeCoercion(RelDataTypeFactory typeFactory, SqlValidator validator) { + super(typeFactory, validator); + } + + @Override + public boolean builtinFunctionCoercion( + SqlCallBinding binding, + List operandTypes, + List expectedFamilies) { + assert binding.getOperandCount() == operandTypes.size(); + if (IntStream.range(0, operandTypes.size()) + .anyMatch(i -> isBlacklistedCoercion(operandTypes.get(i), expectedFamilies.get(i)))) { + return false; + } + return super.builtinFunctionCoercion(binding, operandTypes, expectedFamilies); + } + + // This method tries to blacklist coercions that are not allowed in PPL. 
+ private boolean isBlacklistedCoercion(RelDataType operandType, SqlTypeFamily expectedFamily) { + if (BLACKLISTED_COERCIONS.containsKey(expectedFamily)) { + Set blacklistedFamilies = BLACKLISTED_COERCIONS.get(expectedFamily); + if (blacklistedFamilies.contains(operandType.getSqlTypeName().getFamily())) { + return true; + } + } + return false; + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/validate/PplValidator.java b/core/src/main/java/org/opensearch/sql/calcite/validate/PplValidator.java new file mode 100644 index 00000000000..9eb702259a8 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/calcite/validate/PplValidator.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.validate; + +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.validate.SqlValidatorCatalogReader; +import org.apache.calcite.sql.validate.SqlValidatorImpl; + +public class PplValidator extends SqlValidatorImpl { + /** + * Creates a validator. 
+ * + * @param opTab Operator table + * @param catalogReader Catalog reader + * @param typeFactory Type factory + * @param config Config + */ + protected PplValidator( + SqlOperatorTable opTab, + SqlValidatorCatalogReader catalogReader, + RelDataTypeFactory typeFactory, + Config config) { + super(opTab, catalogReader, typeFactory, config); + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/validate/TypeChecker.java b/core/src/main/java/org/opensearch/sql/calcite/validate/TypeChecker.java new file mode 100644 index 00000000000..0adb4c00c25 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/calcite/validate/TypeChecker.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.validate; + +import org.apache.calcite.jdbc.CalcitePrepare; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.prepare.CalciteCatalogReader; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.server.CalciteServerStatement; +import org.apache.calcite.sql.type.SqlTypeCoercionRule; +import org.apache.calcite.sql.validate.SqlConformanceEnum; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.implicit.TypeCoercion; +import org.apache.calcite.tools.FrameworkConfig; +import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; + +public class TypeChecker { + public static SqlValidator getValidator( + CalciteServerStatement statement, FrameworkConfig config) { + SchemaPlus defaultSchema = config.getDefaultSchema(); + + final CalcitePrepare.Context prepareContext = statement.createPrepareContext(); + final CalciteSchema schema = + defaultSchema != null ? 
CalciteSchema.from(defaultSchema) : prepareContext.getRootSchema(); + CalciteCatalogReader catalogReader = + new CalciteCatalogReader( + schema.root(), + schema.path(null), + OpenSearchTypeFactory.TYPE_FACTORY, + prepareContext.config()); + SqlValidator.Config validatorConfig = + SqlValidator.Config.DEFAULT + .withTypeCoercionRules(getTypeCoercionRule()) + .withTypeCoercionFactory(TypeChecker::createTypeCoercion) + // TODO: should implement one for Calcite + .withConformance(SqlConformanceEnum.LENIENT); + return new PplValidator( + PplOpTable.getInstance(), + catalogReader, + OpenSearchTypeFactory.TYPE_FACTORY, + validatorConfig); + } + + public static SqlTypeCoercionRule getTypeCoercionRule() { + var defaultMapping = SqlTypeCoercionRule.instance().getTypeMapping(); + // NOTE: the default mapping is passed through unmodified; no coercion rules are removed yet + return SqlTypeCoercionRule.instance(defaultMapping); + } + + /** Creates a custom TypeCoercion instance for PPL. This can be used as a TypeCoercionFactory. */ + public static TypeCoercion createTypeCoercion( + RelDataTypeFactory typeFactory, SqlValidator validator) { + return new PplTypeCoercion(typeFactory, validator); + } +} diff --git a/core/src/main/java/org/opensearch/sql/executor/QueryService.java b/core/src/main/java/org/opensearch/sql/executor/QueryService.java index ef8876a9275..f65c10b688f 100644 --- a/core/src/main/java/org/opensearch/sql/executor/QueryService.java +++ b/core/src/main/java/org/opensearch/sql/executor/QueryService.java @@ -10,6 +10,7 @@ import java.security.AccessController; import java.security.PrivilegedAction; +import java.util.Collections; import java.util.List; import java.util.Optional; import lombok.AllArgsConstructor; @@ -17,15 +18,29 @@ import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.RelTraitDef; +import
org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.logical.LogicalSort; import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider; +import org.apache.calcite.rel.rel2sql.RelToSqlConverter; +import org.apache.calcite.rel.rel2sql.SqlImplementor; +import org.apache.calcite.runtime.CalciteContextException; import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.dialect.SparkSqlDialect; import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.util.SqlShuttle; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql2rel.SqlToRelConverter; +import org.apache.calcite.sql2rel.StandardConvertletTable; import org.apache.calcite.tools.FrameworkConfig; import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.Programs; @@ -41,6 +56,7 @@ import org.opensearch.sql.common.setting.Settings.Key; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.exception.CalciteUnsupportedException; +import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.exception.NonFallbackCalciteException; import org.opensearch.sql.planner.PlanContext; import org.opensearch.sql.planner.Planner; @@ -56,6 +72,7 @@ public class QueryService { private final Analyzer analyzer; private final ExecutionEngine executionEngine; private final Planner planner; + private static final RelToSqlConverter converter = new RelToSqlConverter(SparkSqlDialect.DEFAULT); @Getter(lazy = true) private final CalciteRelNodeVisitor relNodeVisitor = new CalciteRelNodeVisitor(); @@ -102,7 +119,8 @@ public void executeWithCalcite( 
settings.getSettingValue(Key.QUERY_SIZE_LIMIT), queryType); RelNode relNode = analyze(plan, context); - RelNode optimized = optimize(relNode); + RelNode validated = validate(relNode, context); + RelNode optimized = optimize(validated); RelNode calcitePlan = convertToCalcitePlan(optimized); executionEngine.execute(calcitePlan, context, listener); return null; @@ -135,6 +153,7 @@ public void explainWithCalcite( CalcitePlanContext.create( buildFrameworkConfig(), getQuerySizeLimit(), queryType); RelNode relNode = analyze(plan, context); - RelNode optimized = optimize(relNode); + RelNode validated = validate(relNode, context); + RelNode optimized = optimize(validated); RelNode calcitePlan = convertToCalcitePlan(optimized); executionEngine.explain(calcitePlan, format, context, listener); @@ -247,6 +266,61 @@ public LogicalPlan analyze(UnresolvedPlan plan, QueryType queryType) { return analyzer.analyze(plan, new AnalysisContext(queryType)); } + private RelNode validate(RelNode relNode, CalcitePlanContext context) { + // Validation + SqlImplementor.Result result = converter.visitRoot(relNode); + SqlNode root = result.asStatement(); + SqlNode rewritten = + root.accept( + new SqlShuttle() { + @Override + public SqlNode visit(SqlIdentifier id) { + // TODO: Maybe not all SqlIdentifier with names of length 2 are db.table + if (id.names.size() == 2) { + // Remove the database qualifier, keep only the table name + return new SqlIdentifier( + Collections.singletonList(id.names.get(1)), id.getParserPosition()); + } + return id; + } + }); + SqlValidator validator = context.getValidator(); + if (rewritten != null) { + try { + String before = rewritten.toString(); + // rewritten will be modified in-place + validator.validate(rewritten); + log.debug("After validation [{}]", rewritten); + String after = rewritten.toString(); + if (before.equals(after)) { + // If the rewritten SQL node is not modified, we can return the original RelNode as is + return relNode; + } + } catch (CalciteContextException e) { + throw new
ExpressionEvaluationException(e.getMessage(), e); + } + } else { + log.debug("Failed to rewrite the SQL node before validation: {}", root); + return relNode; + } + + // Convert the validated SqlNode to RelNode + RelOptTable.ViewExpander viewExpander = context.config.getViewExpander(); + RelOptCluster cluster = context.relBuilder.getCluster(); + CalciteCatalogReader catalogReader = + validator.getCatalogReader().unwrap(CalciteCatalogReader.class); + SqlToRelConverter sql2rel = + new SqlToRelConverter( + viewExpander, + validator, + catalogReader, + cluster, + StandardConvertletTable.INSTANCE, + SqlToRelConverter.config()); + RelRoot validatedRelRoot = sql2rel.convertQuery(rewritten, true, true); + return validatedRelRoot.rel; + } + /** Translate {@link LogicalPlan} to {@link PhysicalPlan}. */ public PhysicalPlan plan(LogicalPlan plan) { return planner.plan(plan); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java index 3da366da827..2a7db7041ec 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java @@ -18,13 +18,23 @@ import org.apache.calcite.avatica.util.TimeUnit; import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.rex.RexCall; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlLibraryOperators; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.fun.SqlTrimFunction; import org.apache.calcite.sql.type.CompositeOperandTypeChecker; import 
org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable; +import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.util.BuiltInMethod; import org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; @@ -102,10 +112,162 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { public static final SqlOperator DIVIDE = new DivideFunction().toUDF("DIVIDE"); public static final SqlOperator SHA2 = CryptographicFunction.sha2().toUDF("SHA2"); public static final SqlOperator CIDRMATCH = new CidrMatchFunction().toUDF("CIDRMATCH"); + public static final SqlOperator LOG = + new SqlFunction( + "LOG", + SqlKind.LOG, + ReturnTypes.DOUBLE_NULLABLE, + null, + OperandTypes.NUMERIC_OPTIONAL_NUMERIC, + SqlFunctionCategory.USER_DEFINED_FUNCTION) { + @Override + public SqlNode rewriteCall(SqlValidator validator, SqlCall call) { + // Rewrite LOG(x, b) to LOG(b, x) + if (call.operandCount() == 2) { + return SqlLibraryOperators.LOG.createCall( + call.getParserPosition(), call.operand(1), call.operand(0)); + } + return super.rewriteCall(validator, call); + } + }; + public static final SqlFunction ATAN = + new SqlFunction( + "ATAN", + SqlKind.OTHER_FUNCTION, + ReturnTypes.DOUBLE_NULLABLE, + null, + OperandTypes.NUMERIC_OPTIONAL_NUMERIC, + SqlFunctionCategory.USER_DEFINED_FUNCTION) { + @Override + public SqlNode rewriteCall(SqlValidator validator, SqlCall call) { + // Rewrite ATAN(x, y) to ATAN2(x, y); the operand order is preserved + if (call.operandCount() == 2) { + return SqlStdOperatorTable.ATAN2.createCall( + call.getParserPosition(), call.operand(0), call.operand(1)); + } + return super.rewriteCall(validator, call); + } + }; + + public static final SqlFunction SQRT = + new SqlFunction( + "SQRT", + SqlKind.OTHER_FUNCTION, +
ReturnTypes.DOUBLE_NULLABLE, + null, + OperandTypes.NUMERIC, + SqlFunctionCategory.USER_DEFINED_FUNCTION) { + @Override + public SqlNode rewriteCall(SqlValidator validator, SqlCall call) { + // Rewrite SQRT(x) to POWER(x, 0.5) + return SqlStdOperatorTable.POWER.createCall( + call.getParserPosition(), + call.operand(0), + SqlLiteral.createExactNumeric("0.5", call.getParserPosition())); + } + }; + + // String functions + public static final SqlFunction TRIM = + new SqlFunction( + "TRIM", + SqlKind.TRIM, + ReturnTypes.VARCHAR_NULLABLE, + null, + OperandTypes.CHARACTER, + SqlFunctionCategory.USER_DEFINED_FUNCTION) { + @Override + public SqlNode rewriteCall(SqlValidator validator, SqlCall call) { + // Rewrite TRIM(x) to TRIM(BOTH, ' ', x) + if (call.operandCount() == 1) { + return SqlStdOperatorTable.TRIM.createCall( + call.getParserPosition(), + SqlLiteral.createSymbol(SqlTrimFunction.Flag.BOTH, call.getParserPosition()), + SqlLiteral.createCharString(" ", call.getParserPosition()), + call.operand(0)); + } + return super.rewriteCall(validator, call); + } + }; + + public static final SqlFunction LTRIM = + new SqlFunction( + "LTRIM", + SqlKind.LTRIM, + ReturnTypes.VARCHAR_NULLABLE, + null, + OperandTypes.CHARACTER, + SqlFunctionCategory.USER_DEFINED_FUNCTION) { + @Override + public SqlNode rewriteCall(SqlValidator validator, SqlCall call) { + // Rewrite LTRIM(x) to TRIM(LEADING, ' ', x) + if (call.operandCount() == 1) { + return SqlStdOperatorTable.TRIM.createCall( + call.getParserPosition(), + SqlLiteral.createSymbol(SqlTrimFunction.Flag.LEADING, call.getParserPosition()), + SqlLiteral.createCharString(" ", call.getParserPosition()), + call.operand(0)); + } + return super.rewriteCall(validator, call); + } + }; + + public static final SqlFunction RTRIM = + new SqlFunction( + "RTRIM", + SqlKind.RTRIM, + ReturnTypes.VARCHAR_NULLABLE, + null, + OperandTypes.CHARACTER, + SqlFunctionCategory.USER_DEFINED_FUNCTION) { + @Override + public SqlNode rewriteCall(SqlValidator 
validator, SqlCall call) { + // Rewrite RTRIM(x) to TRIM(TRAILING, ' ', x) + if (call.operandCount() == 1) { + return SqlStdOperatorTable.TRIM.createCall( + call.getParserPosition(), + SqlLiteral.createSymbol(SqlTrimFunction.Flag.TRAILING, call.getParserPosition()), + SqlLiteral.createCharString(" ", call.getParserPosition()), + call.operand(0)); + } + return super.rewriteCall(validator, call); + } + }; + + public static final SqlFunction STRCMP = + new SqlFunction( + "STRCMP", + SqlKind.OTHER_FUNCTION, + ReturnTypes.INTEGER_NULLABLE, + null, + OperandTypes.STRING_STRING, + SqlFunctionCategory.USER_DEFINED_FUNCTION) { + @Override + public SqlNode rewriteCall(SqlValidator validator, SqlCall call) { + // Rewrite STRCMP(x, y) to STRCMP(y, x) + return SqlLibraryOperators.STRCMP.createCall( + call.getParserPosition(), call.operand(1), call.operand(0)); + } + }; // Condition function public static final SqlOperator EARLIEST = new EarliestFunction().toUDF("EARLIEST"); public static final SqlOperator LATEST = new LatestFunction().toUDF("LATEST"); + public static final SqlFunction XOR = + new SqlFunction( + "XOR", + SqlKind.OTHER_FUNCTION, + ReturnTypes.BOOLEAN_NULLABLE, + null, + OperandTypes.BOOLEAN_BOOLEAN, + SqlFunctionCategory.USER_DEFINED_FUNCTION) { + @Override + public SqlNode rewriteCall(SqlValidator validator, SqlCall call) { + // Rewrite XOR(x, y) to NOT_EQUALS(x, y) + return SqlStdOperatorTable.NOT_EQUALS.createCall( + call.getParserPosition(), call.operand(0), call.operand(1)); + } + }; // Datetime function public static final SqlOperator TIMESTAMP = new TimestampFunction().toUDF("TIMESTAMP"); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 363baff3e4c..5f0aeba6dd3 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ 
b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -61,6 +61,7 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.DAY_OF_YEAR; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DEGREES; import static org.opensearch.sql.expression.function.BuiltinFunctionName.DIVIDE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.DIVIDEFUNCTION; import static org.opensearch.sql.expression.function.BuiltinFunctionName.E; import static org.opensearch.sql.expression.function.BuiltinFunctionName.EARLIEST; import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL; @@ -213,7 +214,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -226,6 +226,7 @@ import java.util.function.BiFunction; import java.util.stream.Collectors; import java.util.stream.Stream; +import lombok.Getter; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; @@ -254,6 +255,7 @@ import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.calcite.utils.PlanUtils; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; +import org.opensearch.sql.calcite.validate.PplOpTable; import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.executor.QueryType; @@ -329,6 +331,18 @@ default PPLTypeChecker getTypeChecker() { final AggBuilder aggBuilder = new AggBuilder(); aggBuilder.populate(); INSTANCE = new PPLFuncImpTable(builder, aggBuilder); + + // Some operators are registered via register instead of registerOperator + // We add them explicitly so that they can be found during validation + var pplOps = PplOpTable.getInstance(); + pplOps.add(JSON_ARRAY, SqlStdOperatorTable.JSON_ARRAY); + 
pplOps.add(JSON_OBJECT, SqlStdOperatorTable.JSON_OBJECT); + pplOps.add(INTERNAL_ITEM, SqlStdOperatorTable.ITEM); + // pplOps.add(TYPEOF, ... ); + pplOps.add(IF, SqlStdOperatorTable.CASE); + pplOps.add(NULLIF, SqlStdOperatorTable.CASE); + pplOps.add(IS_EMPTY, SqlStdOperatorTable.IS_EMPTY); + pplOps.add(IS_BLANK, SqlStdOperatorTable.IS_EMPTY); } /** @@ -344,6 +358,7 @@ default PPLTypeChecker getTypeChecker() { * engine should be registered here. This reduces coupling between the core module and particular * storage backends. */ + @Getter private final Map>> externalFunctionRegistry; @@ -352,7 +367,7 @@ default PPLTypeChecker getTypeChecker() { * implementations are independent of any specific data storage, should be registered here * internally. */ - private final ImmutableMap aggFunctionRegistry; + @Getter private final ImmutableMap aggFunctionRegistry; /** * The external agg function registry. Agg Functions whose implementations depend on a specific @@ -445,9 +460,9 @@ public RexNode resolve( List argTypes = Arrays.stream(args).map(RexNode::getType).toList(); try { for (Map.Entry implement : implementList) { - if (implement.getKey().match(functionName.getName(), argTypes)) { - return implement.getValue().resolve(builder, args); - } + // if (implement.getKey().match(functionName.getName(), argTypes)) { + return implement.getValue().resolve(builder, args); + // } } } catch (Exception e) { throw new ExpressionEvaluationException( @@ -516,6 +531,25 @@ void registerOperator(BuiltinFunctionName functionName, SqlOperator operator) { functionName, (RexBuilder builder, RexNode... node) -> builder.makeCall(operator, node)); } + + // Currently, only functions registered via registerOperator is added to PPLOpTable + registerToCatalogWithReplace(functionName, operator); + } + + private static void registerToCatalogWithReplace( + BuiltinFunctionName functionName, SqlOperator operator) { + // replacement contains the real implementations -- some operators are rewritten. 
+ final Map replacement = + Map.of( + LOG, + SqlLibraryOperators.LOG, + TRIM, + SqlStdOperatorTable.TRIM, + STRCMP, + SqlLibraryOperators.STRCMP, + XOR, + SqlStdOperatorTable.NOT_EQUALS); + PplOpTable.getInstance().add(functionName, replacement.getOrDefault(functionName, operator)); } private static SqlOperandTypeChecker extractTypeCheckerFromUDF( @@ -698,6 +732,10 @@ void populate() { registerOperator(MODULUSFUNCTION, PPLBuiltinOperators.MOD); registerOperator(CRC32, PPLBuiltinOperators.CRC32); registerOperator(DIVIDE, PPLBuiltinOperators.DIVIDE); + registerOperator(DIVIDEFUNCTION, PPLBuiltinOperators.DIVIDE); + // SqlStdOperatorTable.SQRT is declared but not implemented. The call to SQRT in Calcite is + // converted to POWER(x, 0.5). + registerOperator(SQRT, PPLBuiltinOperators.SQRT); registerOperator(SHA2, PPLBuiltinOperators.SHA2); registerOperator(CIDRMATCH, PPLBuiltinOperators.CIDRMATCH); registerOperator(INTERNAL_GROK, PPLBuiltinOperators.GROK); @@ -708,6 +746,7 @@ void populate() { registerOperator(SIMPLE_QUERY_STRING, PPLBuiltinOperators.SIMPLE_QUERY_STRING); registerOperator(QUERY_STRING, PPLBuiltinOperators.QUERY_STRING); registerOperator(MULTI_MATCH, PPLBuiltinOperators.MULTI_MATCH); + registerOperator(LOG, PPLBuiltinOperators.LOG); // Register PPL Datetime UDF operator registerOperator(TIMESTAMP, PPLBuiltinOperators.TIMESTAMP); @@ -816,63 +855,16 @@ void populate() { // Register implementation. // Note, make the implementation an individual class if too complex. 
- register( - TRIM, - createFunctionImpWithTypeChecker( - (builder, arg) -> - builder.makeCall( - SqlStdOperatorTable.TRIM, - builder.makeFlag(Flag.BOTH), - builder.makeLiteral(" "), - arg), - PPLTypeChecker.family(SqlTypeFamily.STRING))); + registerOperator(TRIM, PPLBuiltinOperators.TRIM); + registerOperator(LTRIM, PPLBuiltinOperators.LTRIM); + registerOperator(RTRIM, PPLBuiltinOperators.RTRIM); - register( - LTRIM, - createFunctionImpWithTypeChecker( - (builder, arg) -> - builder.makeCall( - SqlStdOperatorTable.TRIM, - builder.makeFlag(Flag.LEADING), - builder.makeLiteral(" "), - arg), - PPLTypeChecker.family(SqlTypeFamily.STRING))); - register( - RTRIM, - createFunctionImpWithTypeChecker( - (builder, arg) -> - builder.makeCall( - SqlStdOperatorTable.TRIM, - builder.makeFlag(Flag.TRAILING), - builder.makeLiteral(" "), - arg), - PPLTypeChecker.family(SqlTypeFamily.STRING))); - register( - ATAN, - createFunctionImpWithTypeChecker( - (builder, arg1, arg2) -> builder.makeCall(SqlStdOperatorTable.ATAN2, arg1, arg2), - PPLTypeChecker.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC))); - register( - STRCMP, - createFunctionImpWithTypeChecker( - (builder, arg1, arg2) -> builder.makeCall(SqlLibraryOperators.STRCMP, arg2, arg1), - PPLTypeChecker.family(SqlTypeFamily.STRING, SqlTypeFamily.STRING))); + registerOperator(ATAN, PPLBuiltinOperators.ATAN); + registerOperator(STRCMP, PPLBuiltinOperators.STRCMP); // SqlStdOperatorTable.SUBSTRING.getOperandTypeChecker is null. We manually create a type // checker for it. 
- register( - SUBSTRING, - wrapWithCompositeTypeChecker( - SqlStdOperatorTable.SUBSTRING, - (CompositeOperandTypeChecker) - OperandTypes.STRING_INTEGER.or(OperandTypes.STRING_INTEGER_INTEGER), - false)); - register( - SUBSTR, - wrapWithCompositeTypeChecker( - SqlStdOperatorTable.SUBSTRING, - (CompositeOperandTypeChecker) - OperandTypes.STRING_INTEGER.or(OperandTypes.STRING_INTEGER_INTEGER), - false)); + registerOperator(SUBSTRING, SqlStdOperatorTable.SUBSTRING); + registerOperator(SUBSTR, SqlStdOperatorTable.SUBSTRING); // SqlStdOperatorTable.ITEM.getOperandTypeChecker() checks only the first operand instead of // all operands. Therefore, we wrap it with a custom CompositeOperandTypeChecker to check both // operands. @@ -884,37 +876,12 @@ void populate() { OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER) .or(OperandTypes.family(SqlTypeFamily.MAP, SqlTypeFamily.ANY)), false)); - register( - LOG, - createFunctionImpWithTypeChecker( - (builder, arg1, arg2) -> builder.makeCall(SqlLibraryOperators.LOG, arg2, arg1), - PPLTypeChecker.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC))); - register( - LOG, - createFunctionImpWithTypeChecker( - (builder, arg) -> - builder.makeCall( - SqlLibraryOperators.LOG, - arg, - builder.makeApproxLiteral(BigDecimal.valueOf(Math.E))), - PPLTypeChecker.family(SqlTypeFamily.NUMERIC))); - // SqlStdOperatorTable.SQRT is declared but not implemented. The call to SQRT in Calcite is - // converted to POWER(x, 0.5). - register( - SQRT, - createFunctionImpWithTypeChecker( - (builder, arg) -> - builder.makeCall( - SqlStdOperatorTable.POWER, - arg, - builder.makeApproxLiteral(BigDecimal.valueOf(0.5))), - PPLTypeChecker.family(SqlTypeFamily.NUMERIC))); register( TYPEOF, (FunctionImp1) (builder, arg) -> builder.makeLiteral(getLegacyTypeName(arg.getType(), QueryType.PPL))); - register(XOR, new XOR_FUNC()); + registerOperator(XOR, PPLBuiltinOperators.XOR); // SqlStdOperatorTable.CASE.getOperandTypeChecker is null. 
We manually create a type checker // for it. The second and third operands are required to be of the same type. If not, // it will throw an IllegalArgumentException with information Can't find leastRestrictive type diff --git a/core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java b/core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java index fcd7a6a2be5..5fd64702553 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java @@ -5,7 +5,6 @@ package org.opensearch.sql.expression.function; -import java.util.Collections; import java.util.List; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; @@ -48,14 +47,12 @@ public SqlTypeFamily getOperandSqlTypeFamily(int iFormalOperand) { @Override public List paramTypes(RelDataTypeFactory typeFactory) { - // This function is not used in the current context, so we return an empty list. - return Collections.emptyList(); + throw new IllegalStateException("paramTypes is not implemented for UDFOperandMetadata"); } @Override public List paramNames() { - // This function is not used in the current context, so we return an empty list. - return Collections.emptyList(); + throw new IllegalStateException("paramNames is not implemented for UDFOperandMetadata"); } @Override @@ -103,14 +100,14 @@ public SqlTypeFamily getOperandSqlTypeFamily(int iFormalOperand) { @Override public List paramTypes(RelDataTypeFactory typeFactory) { - // This function is not used in the current context, so we return an empty list. - return Collections.emptyList(); + throw new IllegalStateException( + "paramTypes is not supported for CompositeOperandTypeChecker"); } @Override public List paramNames() { - // This function is not used in the current context, so we return an empty list. 
- return Collections.emptyList(); + throw new IllegalStateException( + "paramNames is not supported for CompositeOperandTypeChecker"); } @Override diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java index 2e70a210d6e..39368b8b1ff 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java @@ -28,6 +28,7 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.dialect.SparkSqlDialect; import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.test.CalciteAssert; import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.Programs; @@ -38,6 +39,7 @@ import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.calcite.CalcitePlanContext; import org.opensearch.sql.calcite.CalciteRelNodeVisitor; +import org.opensearch.sql.calcite.validate.TypeChecker; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.common.setting.Settings.Key; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; @@ -49,12 +51,14 @@ public class CalcitePPLAbstractTest { private final CalciteRelNodeVisitor planTransformer; private final RelToSqlConverter converter; protected final Settings settings; + private SqlValidator validator; public CalcitePPLAbstractTest(CalciteAssert.SchemaSpec... schemaSpecs) { this.config = config(schemaSpecs); this.planTransformer = new CalciteRelNodeVisitor(); this.converter = new RelToSqlConverter(SparkSqlDialect.DEFAULT); this.settings = mock(Settings.class); + this.validator = TypeChecker.getValidator(null, config.build()); } public PPLSyntaxParser pplParser = new PPLSyntaxParser();