NOTE(review): this patch was reconstructed from a copy whose line breaks and angle-bracket
spans had been stripped. Indentation and blank context lines were rebuilt to match the hunk
counts. The element type of postParseRules() (written here as SqlVisitor<SqlNode>) and the
"<udt>" placeholder in the CoercionRule javadoc were not recoverable from the damaged text;
confirm both against the working tree before applying.

diff --git a/api/src/main/java/org/opensearch/sql/api/UnifiedQueryPlanner.java b/api/src/main/java/org/opensearch/sql/api/UnifiedQueryPlanner.java
index edf9ae50e18..4ab47f36bb3 100644
--- a/api/src/main/java/org/opensearch/sql/api/UnifiedQueryPlanner.java
+++ b/api/src/main/java/org/opensearch/sql/api/UnifiedQueryPlanner.java
@@ -60,7 +60,15 @@ public UnifiedQueryPlanner(UnifiedQueryContext context) {
    */
   public RelNode plan(String query) {
     try {
-      return context.measure(ANALYZE, () -> strategy.plan(query));
+      return context.measure(
+          ANALYZE,
+          () -> {
+            RelNode plan = strategy.plan(query);
+            for (var rule : context.getLangSpec().postAnalysisRules()) {
+              plan = rule.apply(plan);
+            }
+            return plan;
+          });
     } catch (SyntaxCheckException | UnsupportedOperationException e) {
       throw e;
     } catch (Exception e) {
diff --git a/api/src/main/java/org/opensearch/sql/api/spec/LanguageSpec.java b/api/src/main/java/org/opensearch/sql/api/spec/LanguageSpec.java
index 89167dc27a5..49f905f50a1 100644
--- a/api/src/main/java/org/opensearch/sql/api/spec/LanguageSpec.java
+++ b/api/src/main/java/org/opensearch/sql/api/spec/LanguageSpec.java
@@ -7,6 +7,7 @@
 
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.sql.SqlNode;
 import org.apache.calcite.sql.SqlOperatorTable;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
@@ -17,8 +18,8 @@
 
 /**
  * Language specification defining the dialect the engine accepts. Provides parser configuration,
- * validator configuration, and composable {@link LanguageExtension}s that contribute operators and
- * post-parse rewrite rules.
+ * validator configuration, and composable {@link LanguageExtension}s that contribute operators,
+ * post-parse rewrite rules, and post-analysis rewrite rules.
  *
  * <p>Implementations define a complete language surface — for example, {@link UnifiedSqlSpec}
  * provides ANSI and extended SQL modes. A future PPL spec would implement this same interface once
@@ -27,8 +28,18 @@
 public interface LanguageSpec {
 
   /**
-   * A composable language extension that contributes operators and post-parse rewrite rules. All
-   * methods have defaults so extensions only override what they need.
+   * A RelNode rewrite rule applied after analysis and before execution. Takes a logical plan and
+   * returns a rewritten plan.
+   */
+  @FunctionalInterface
+  interface PostAnalysisRule {
+    RelNode apply(RelNode plan);
+  }
+
+  /**
+   * A composable language extension that contributes operators, post-parse rewrite rules, and
+   * post-analysis rewrite rules. All methods have defaults so extensions only override what they
+   * need.
    */
   interface LanguageExtension {
 
@@ -47,6 +58,15 @@ default SqlOperatorTable operators() {
     default List<SqlVisitor<SqlNode>> postParseRules() {
       return List.of();
     }
+
+    /**
+     * RelNode rewrite rules applied after analysis and before execution. Rules within a single
+     * extension are applied in list order; extensions that depend on ordering should return their
+     * rules together from one extension rather than relying on cross-extension ordering.
+     */
+    default List<PostAnalysisRule> postAnalysisRules() {
+      return List.of();
+    }
   }
 
   /**
@@ -62,9 +82,9 @@ default List<SqlVisitor<SqlNode>> postParseRules() {
   SqlValidator.Config validatorConfig();
 
   /**
-   * Language extensions registered with this spec. Each extension contributes operators and
-   * post-parse rewrite rules that are composed by {@link #operatorTable()} and {@link
-   * #postParseRules()}.
+   * Language extensions registered with this spec. Each extension contributes operators, post-parse
+   * rewrite rules, and post-analysis rewrite rules composed by {@link #operatorTable()}, {@link
+   * #postParseRules()}, and {@link #postAnalysisRules()}.
    */
   List<LanguageExtension> extensions();
 
@@ -86,4 +106,12 @@ default SqlOperatorTable operatorTable() {
   default List<SqlVisitor<SqlNode>> postParseRules() {
     return extensions().stream().flatMap(ext -> ext.postParseRules().stream()).toList();
   }
+
+  /**
+   * All post-analysis RelNode rewrite rules from registered extensions, flattened in registration
+   * order. Applied to the logical plan after analysis and before execution.
+   */
+  default List<PostAnalysisRule> postAnalysisRules() {
+    return extensions().stream().flatMap(ext -> ext.postAnalysisRules().stream()).toList();
+  }
 }
diff --git a/api/src/main/java/org/opensearch/sql/api/spec/UnifiedPplSpec.java b/api/src/main/java/org/opensearch/sql/api/spec/UnifiedPplSpec.java
index 763f6ded540..781f75bf0bd 100644
--- a/api/src/main/java/org/opensearch/sql/api/spec/UnifiedPplSpec.java
+++ b/api/src/main/java/org/opensearch/sql/api/spec/UnifiedPplSpec.java
@@ -10,6 +10,7 @@
 import lombok.NoArgsConstructor;
 import org.apache.calcite.sql.parser.SqlParser;
 import org.apache.calcite.sql.validate.SqlValidator;
+import org.opensearch.sql.api.spec.datetime.DatetimeUdtExtension;
 
 /**
  * PPL language specification.
@@ -37,6 +38,6 @@ public SqlValidator.Config validatorConfig() {
 
   @Override
   public List<LanguageExtension> extensions() {
-    return List.of();
+    return List.of(new DatetimeUdtExtension());
   }
 }
diff --git a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtExtension.java b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtExtension.java
new file mode 100644
index 00000000000..43af99bdeb8
--- /dev/null
+++ b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtExtension.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.api.spec.datetime;
+
+import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.calcite.rel.RelHomogeneousShuttle;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.opensearch.sql.api.spec.LanguageSpec.LanguageExtension;
+import org.opensearch.sql.api.spec.LanguageSpec.PostAnalysisRule;
+import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT;
+
+/**
+ * Bridges PPL's datetime UDT semantics with standard Calcite datetime types in the unified query
+ * API, so PPL queries over standard schemas behave the same as PPL queries over OpenSearch.
+ */
+public class DatetimeUdtExtension implements LanguageExtension {
+
+  @Override
+  public List<PostAnalysisRule> postAnalysisRules() {
+    return List.of(new CoercionRule());
+  }
+
+  /**
+   * Wraps every standard DATE/TIME/TIMESTAMP expression with {@code CAST(x AS <udt>)}. UDT
+   * expressions (already backed by the same base type) are left alone.
+   */
+  static class CoercionRule implements PostAnalysisRule {
+
+    /** Standard datetime type → corresponding PPL UDT. */
+    private static final Map<SqlTypeName, RelDataType> STD_TO_UDT =
+        Map.of(
+            SqlTypeName.DATE, TYPE_FACTORY.createUDT(ExprUDT.EXPR_DATE),
+            SqlTypeName.TIME, TYPE_FACTORY.createUDT(ExprUDT.EXPR_TIME),
+            SqlTypeName.TIMESTAMP, TYPE_FACTORY.createUDT(ExprUDT.EXPR_TIMESTAMP));
+
+    @Override
+    public RelNode apply(RelNode plan) {
+      return plan.accept(
+          new RelHomogeneousShuttle() {
+            @Override
+            public RelNode visit(RelNode other) {
+              RelNode visited = super.visit(other);
+              RexBuilder rexBuilder = visited.getCluster().getRexBuilder();
+              return visited.accept(
+                  new RexShuttle() {
+                    @Override
+                    public RexNode visitInputRef(RexInputRef ref) {
+                      return wrap(ref);
+                    }
+
+                    @Override
+                    public RexNode visitLiteral(RexLiteral literal) {
+                      return wrap(literal);
+                    }
+
+                    @Override
+                    public RexNode visitCall(RexCall call) {
+                      return wrap(super.visitCall(call));
+                    }
+
+                    private RexNode wrap(RexNode node) {
+                      RelDataType udt = STD_TO_UDT.get(node.getType().getSqlTypeName());
+                      if (udt == null) {
+                        return node;
+                      }
+                      return rexBuilder.makeCast(
+                          rexBuilder
+                              .getTypeFactory()
+                              .createTypeWithNullability(udt, node.getType().isNullable()),
+                          node);
+                    }
+                  });
+            }
+          });
+    }
+  }
+}
diff --git a/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtExtensionTest.java b/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtExtensionTest.java
new file mode 100644
index 00000000000..ea36be2f24f
--- /dev/null
+++ b/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtExtensionTest.java
@@ -0,0 +1,124 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.api.spec.datetime;
+
+import static java.sql.Types.INTEGER;
+import static java.sql.Types.VARCHAR;
+
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.time.LocalDate;
+import java.time.LocalTime;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.schema.Table;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.junit.Before;
+import org.junit.Test;
+import org.opensearch.sql.api.ResultSetAssertion;
+import org.opensearch.sql.api.UnifiedQueryTestBase;
+import org.opensearch.sql.api.compiler.UnifiedQueryCompiler;
+
+public class DatetimeUdtExtensionTest extends UnifiedQueryTestBase implements ResultSetAssertion {
+
+  private UnifiedQueryCompiler compiler;
+
+  @Before
+  public void setUp() {
+    super.setUp();
+    compiler = new UnifiedQueryCompiler(context);
+  }
+
+  @Override
+  protected Table createEmployeesTable() {
+    return SimpleTable.builder()
+        .col("name", SqlTypeName.VARCHAR)
+        .col("hire_date", SqlTypeName.DATE)
+        .col("login_time", SqlTypeName.TIME)
+        .col("updated_at", SqlTypeName.TIMESTAMP)
+        .row(
+            new Object[] {
+              "Alice",
+              (int) LocalDate.of(2020, 3, 15).toEpochDay(),
+              (int) (LocalTime.of(9, 30).toNanoOfDay() / 1_000_000),
+              1705312200000L
+            })
+        .build();
+  }
+
+  private ResultSet execute(RelNode plan) throws Exception {
+    PreparedStatement stmt = compiler.compile(plan);
+    return stmt.executeQuery();
+  }
+
+  @Test
+  public void castsStdDatetimeAsPplUdfOperand() throws Exception {
+    RelNode plan =
+        givenQuery("source = catalog.employees | eval y = YEAR(hire_date) | fields y")
+            .assertPlan(
+                """
+                LogicalProject(y=[YEAR(CAST($1):EXPR_DATE VARCHAR NOT NULL)])
+                  LogicalTableScan(table=[[catalog, employees]])
+                """)
+            .plan();
+    verify(execute(plan)).expectSchema(col("y", INTEGER)).expectData(row(2020));
+  }
+
+  @Test
+  public void castsStdDatetimeInUdtComparison() throws Exception {
+    RelNode plan =
+        givenQuery(
+                "source = catalog.employees | where hire_date > DATE('2020-01-01') | fields name")
+            .assertPlan(
+                """
+                LogicalProject(name=[$0])
+                  LogicalFilter(condition=[>(CAST($1):EXPR_DATE VARCHAR NOT NULL, DATE('2020-01-01':VARCHAR))])
+                    LogicalTableScan(table=[[catalog, employees]])
+                """)
+            .plan();
+    verify(execute(plan)).expectSchema(col("name", VARCHAR)).expectData(row("Alice"));
+  }
+
+  @Test
+  public void leavesUdtReturnTypeUntouched() throws Exception {
+    RelNode plan =
+        givenQuery("source = catalog.employees | eval d = LAST_DAY(hire_date) | fields d")
+            .assertPlan(
+                """
+                LogicalProject(d=[LAST_DAY(CAST($1):EXPR_DATE VARCHAR NOT NULL)])
+                  LogicalTableScan(table=[[catalog, employees]])
+                """)
+            .plan();
+    verify(execute(plan)).expectSchema(col("d", VARCHAR)).expectData(row("2020-03-31"));
+  }
+
+  @Test
+  public void castsBareStdDatetimeInProjection() throws Exception {
+    RelNode plan =
+        givenQuery("source = catalog.employees | fields hire_date, login_time")
+            .assertPlan(
+                """
+                LogicalProject(hire_date=[CAST($1):EXPR_DATE VARCHAR NOT NULL], login_time=[CAST($2):EXPR_TIME VARCHAR NOT NULL])
+                  LogicalTableScan(table=[[catalog, employees]])
+                """)
+            .plan();
+    verify(execute(plan))
+        .expectSchema(col("hire_date", VARCHAR), col("login_time", VARCHAR))
+        .expectData(row("2020-03-15", "09:30:00"));
+  }
+
+  @Test
+  public void leavesNonDatetimeUntouched() throws Exception {
+    RelNode plan =
+        givenQuery("source = catalog.employees | fields name")
+            .assertPlan(
+                """
+                LogicalProject(name=[$0])
+                  LogicalTableScan(table=[[catalog, employees]])
+                """)
+            .plan();
+    verify(execute(plan)).expectSchema(col("name", VARCHAR)).expectData(row("Alice"));
+  }
+}