Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,15 @@ public UnifiedQueryPlanner(UnifiedQueryContext context) {
*/
public RelNode plan(String query) {
try {
return context.measure(ANALYZE, () -> strategy.plan(query));
return context.measure(
ANALYZE,
() -> {
RelNode plan = strategy.plan(query);
for (var rule : context.getLangSpec().postAnalysisRules()) {
plan = rule.apply(plan);
}
return plan;
});
} catch (SyntaxCheckException | UnsupportedOperationException e) {
throw e;
} catch (Exception e) {
Expand Down
42 changes: 35 additions & 7 deletions api/src/main/java/org/opensearch/sql/api/spec/LanguageSpec.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
Expand All @@ -17,8 +18,8 @@

/**
* Language specification defining the dialect the engine accepts. Provides parser configuration,
* validator configuration, and composable {@link LanguageExtension}s that contribute operators,
* post-parse rewrite rules, and post-analysis rewrite rules.
*
* <p>Implementations define a complete language surface — for example, {@link UnifiedSqlSpec}
* provides ANSI and extended SQL modes. A future PPL spec would implement this same interface once
Expand All @@ -27,8 +28,18 @@
public interface LanguageSpec {

/**
* A RelNode rewrite rule applied after analysis and before execution. Takes a logical plan and
* returns a rewritten plan.
*/
@FunctionalInterface
interface PostAnalysisRule {
/**
* Rewrites the given logical plan.
*
* @param plan the logical plan produced by analysis (or by a preceding rule)
* @return the rewritten plan
*/
RelNode apply(RelNode plan);
}

/**
* A composable language extension that contributes operators, post-parse rewrite rules, and
* post-analysis rewrite rules. All methods have defaults so extensions only override what they
* need.
*/
interface LanguageExtension {

Expand All @@ -47,6 +58,15 @@ default SqlOperatorTable operators() {
default List<SqlVisitor<SqlNode>> postParseRules() {
// Default: an extension contributes no post-parse rules unless it overrides this method.
return List.of();
}

/**
* RelNode rewrite rules applied after analysis and before execution. Rules within a single
* extension are applied in list order; rules that depend on running in a particular order should
* be returned together from one extension rather than relying on cross-extension ordering.
*
* @return the rules contributed by this extension; empty by default
*/
default List<PostAnalysisRule> postAnalysisRules() {
return List.of();
}
}

/**
Expand All @@ -62,9 +82,9 @@ default List<SqlVisitor<SqlNode>> postParseRules() {
SqlValidator.Config validatorConfig();

/**
* Language extensions registered with this spec. Each extension contributes operators, post-parse
* rewrite rules, and post-analysis rewrite rules composed by {@link #operatorTable()}, {@link
* #postParseRules()}, and {@link #postAnalysisRules()}.
*/
List<LanguageExtension> extensions();

Expand All @@ -86,4 +106,12 @@ default SqlOperatorTable operatorTable() {
default List<SqlVisitor<SqlNode>> postParseRules() {
  // Concatenate every extension's rule list, keeping the order in which extensions() lists them.
  return extensions().stream()
      .map(LanguageExtension::postParseRules)
      .flatMap(List::stream)
      .toList();
}

/**
 * Collects the post-analysis RelNode rewrite rules of every registered extension into one flat
 * list. Extensions are visited in the order returned by {@link #extensions()}, and each
 * extension's rules keep their own list order.
 *
 * @return the combined rule list, to be applied to the logical plan after analysis and before
 *     execution
 */
default List<PostAnalysisRule> postAnalysisRules() {
  return extensions().stream()
      .map(LanguageExtension::postAnalysisRules)
      .flatMap(List::stream)
      .toList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import lombok.NoArgsConstructor;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.validate.SqlValidator;
import org.opensearch.sql.api.spec.datetime.DatetimeUdtExtension;

/**
* PPL language specification.
Expand Down Expand Up @@ -37,6 +38,6 @@ public SqlValidator.Config validatorConfig() {

/**
 * Extensions registered with this spec.
 *
 * @return a single-element list holding a {@link DatetimeUdtExtension}
 */
@Override
public List<LanguageExtension> extensions() {
  // Only one return may remain: a leftover `return List.of();` ahead of this line would make
  // the registration below unreachable (a compile error) and leave the extension unregistered.
  return List.of(new DatetimeUdtExtension());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.api.spec.datetime;

import java.util.List;
import java.util.Optional;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.calcite.avatica.util.DateTimeUtils;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.sql.type.SqlTypeName;
import org.opensearch.sql.api.spec.LanguageSpec.LanguageExtension;
import org.opensearch.sql.api.spec.LanguageSpec.PostAnalysisRule;
import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT;

/**
 * Language extension that normalizes datetime UDT operations in the logical plan and then casts
 * any remaining datetime output columns to VARCHAR, so that the wire output keeps PPL's String
 * datetime contract.
 *
 * <p>It contributes three post-analysis rules whose relative order matters:
 *
 * <ol>
 *   <li>{@link DatetimeUdtLiteralCoercionRule} casts VARCHAR operands that sit next to standard
 *       datetime operands in non-UDF operators, so later rules see homogeneous types;
 *   <li>{@link DatetimeUdtNormalizeRule} rewrites UDT calls;
 *   <li>{@link DatetimeUdtOutputCastRule} wraps the root with a varchar projection.
 * </ol>
 *
 * <p>The output cast depends on the normalized row type, which is why all three rules are
 * returned from this one extension instead of being spread across several.
 */
public class DatetimeUdtExtension implements LanguageExtension {

  /**
   * Returns the datetime rewrite rules in the order they must run: literal coercion, then UDT
   * normalization, then the output cast.
   */
  @Override
  public List<PostAnalysisRule> postAnalysisRules() {
    PostAnalysisRule literalCoercion = new DatetimeUdtLiteralCoercionRule();
    PostAnalysisRule normalize = new DatetimeUdtNormalizeRule();
    PostAnalysisRule outputCast = new DatetimeUdtOutputCastRule();
    return List.of(literalCoercion, normalize, outputCast);
  }

  /**
   * Pairs each datetime UDT with its standard Calcite type and with the names of the {@link
   * DateTimeUtils} methods used to convert values in either direction.
   */
  @Getter
  @RequiredArgsConstructor
  enum UdtMapping {
    DATE(ExprUDT.EXPR_DATE, SqlTypeName.DATE, "dateStringToUnixDate", "unixDateToString"),
    TIME(ExprUDT.EXPR_TIME, SqlTypeName.TIME, "timeStringToUnixDate", "unixTimeToString"),
    TIMESTAMP(
        ExprUDT.EXPR_TIMESTAMP,
        SqlTypeName.TIMESTAMP,
        "timestampStringToUnixDate",
        "unixTimestampToString");

    /** The UDT side of the mapping. */
    private final ExprUDT udtType;

    /** The standard Calcite type name the UDT corresponds to. */
    private final SqlTypeName stdType;

    /** Name of the {@link DateTimeUtils} method converting a String to the standard value. */
    private final String toStdMethod;

    /** Name of the {@link DateTimeUtils} method converting a standard value to a String. */
    private final String fromStdMethod;

    /**
     * Looks up the mapping for a UDT type.
     *
     * @param type any type; only {@link AbstractExprRelDataType} instances can match
     * @return the mapping whose UDT equals the type's UDT, or empty if there is none
     */
    static Optional<UdtMapping> fromUdtType(RelDataType type) {
      if (type instanceof AbstractExprRelDataType<?> exprType) {
        ExprUDT wanted = exprType.getUdt();
        return List.of(values()).stream()
            .filter(mapping -> mapping.udtType == wanted)
            .findFirst();
      }
      return Optional.empty();
    }

    /**
     * Looks up the mapping for a standard Calcite type.
     *
     * @param type the type whose {@link SqlTypeName} is matched
     * @return the mapping with that standard type name, or empty if there is none
     */
    static Optional<UdtMapping> fromStdType(RelDataType type) {
      SqlTypeName wanted = type.getSqlTypeName();
      return List.of(values()).stream()
          .filter(mapping -> mapping.stdType == wanted)
          .findFirst();
    }

    /**
     * Builds the standard Calcite type of this mapping with the requested nullability.
     *
     * @param rexBuilder supplies the type factory
     * @param nullable whether the resulting type is nullable
     */
    RelDataType toStdType(RexBuilder rexBuilder, boolean nullable) {
      var typeFactory = rexBuilder.getTypeFactory();
      RelDataType baseType = typeFactory.createSqlType(stdType);
      return typeFactory.createTypeWithNullability(baseType, nullable);
    }

    /** UDT value (String) → standard value (int/long), via {@code toString()} on the input. */
    Expression toStdValue(Expression result) {
      Expression asString = Expressions.call(result, "toString");
      return Expressions.call(DateTimeUtils.class, toStdMethod, asString);
    }

    /** Standard value (int/long) → UDT value (String). */
    Expression fromStdValue(Expression operand) {
      return Expressions.call(DateTimeUtils.class, fromStdMethod, operand);
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.api.spec.datetime;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.opensearch.sql.api.spec.LanguageSpec.PostAnalysisRule;
import org.opensearch.sql.api.spec.datetime.DatetimeUdtExtension.UdtMapping;

/**
 * Post-analysis rule that wraps VARCHAR/CHAR operands in {@code CAST(... AS DATE|TIME|TIMESTAMP)}
 * when they appear next to a standard datetime operand inside a non-UDF operator (comparisons,
 * IN, BETWEEN/SEARCH, COALESCE). It covers the cases {@link DatetimeUdtNormalizeRule} does not
 * touch, since that rule only rewrites operators backed by {@code ImplementableUDFunction}.
 *
 * <p>The rewrite is limited to operand sub-trees of {@code RexCall} nodes: the call keeps its own
 * type, no {@code RelNode} row type changes and no {@code RexInputRef} slot identity is altered.
 * That keeps it safe with respect to Calcite's cached {@code RexInputRef} types, which an in-place
 * ref rewrite would invalidate in parent nodes.
 */
public class DatetimeUdtLiteralCoercionRule implements PostAnalysisRule {

  /**
   * Walks the whole plan and applies the operand coercion to the expressions of every node.
   *
   * @param plan the logical plan to rewrite
   * @return the plan with coerced operands
   */
  @Override
  public RelNode apply(RelNode plan) {
    // The expression shuttle only holds the RexBuilder, so one instance serves every node.
    RexShuttle coercion = new LiteralCoercionShuttle(plan.getCluster().getRexBuilder());
    RelHomogeneousShuttle planWalker =
        new RelHomogeneousShuttle() {
          @Override
          public RelNode visit(RelNode other) {
            // Rewrite the inputs first, then the expressions of the node itself.
            RelNode withRewrittenInputs = super.visit(other);
            return withRewrittenInputs.accept(coercion);
          }
        };
    return plan.accept(planWalker);
  }

  /** Expression shuttle that performs the VARCHAR → datetime operand coercion on calls. */
  private static class LiteralCoercionShuttle extends RexShuttle {

    private final RexBuilder rexBuilder;

    LiteralCoercionShuttle(RexBuilder rexBuilder) {
      this.rexBuilder = rexBuilder;
    }

    /**
     * Rewrites nested operands first, then coerces the character operands of a targeted call
     * that has at least one standard datetime operand. Any other call is returned as rewritten
     * by the superclass.
     */
    @Override
    public RexNode visitCall(RexCall call) {
      RexCall rewritten = (RexCall) super.visitCall(call);
      if (!isTargetOperator(rewritten)) {
        return rewritten;
      }
      List<RexNode> operands = rewritten.getOperands();
      return findDatetimeOperand(rewritten)
          .map(mapping -> coerceVarcharOperands(operands, mapping))
          .filter(coerced -> !coerced.equals(operands))
          // Same call type, new operands: only the operand sub-trees change.
          .<RexNode>map(coerced -> rewritten.clone(rewritten.getType(), coerced))
          .orElse(rewritten);
    }

    /** Tells whether the call's kind is one where VARCHAR ↔ datetime coercion is performed. */
    private static boolean isTargetOperator(RexCall call) {
      return switch (call.getKind()) {
        case EQUALS,
            NOT_EQUALS,
            GREATER_THAN,
            GREATER_THAN_OR_EQUAL,
            LESS_THAN,
            LESS_THAN_OR_EQUAL,
            IN,
            SEARCH,
            BETWEEN,
            COALESCE -> true;
        default -> false;
      };
    }

    /** Finds the mapping of the first operand whose type is a standard Calcite datetime. */
    private static Optional<UdtMapping> findDatetimeOperand(RexCall call) {
      return call.getOperands().stream()
          .map(operand -> UdtMapping.fromStdType(operand.getType()))
          .flatMap(Optional::stream)
          .findFirst();
    }

    /**
     * Wraps every VARCHAR/CHAR operand in {@code CAST(... AS <datetime>)}, keeping the operand's
     * own nullability on the cast target.
     *
     * @return the given list itself when it holds no character operand, otherwise a new list
     */
    private List<RexNode> coerceVarcharOperands(List<RexNode> operands, UdtMapping datetime) {
      boolean nothingToCoerce =
          operands.stream().noneMatch(operand -> isCharType(operand.getType()));
      if (nothingToCoerce) {
        return operands;
      }
      List<RexNode> result = new ArrayList<>(operands.size());
      for (RexNode operand : operands) {
        RelDataType operandType = operand.getType();
        if (!isCharType(operandType)) {
          result.add(operand);
          continue;
        }
        RelDataType target = datetime.toStdType(rexBuilder, operandType.isNullable());
        result.add(rexBuilder.makeCast(target, operand));
      }
      return result;
    }

    /** True for CHAR and VARCHAR types. */
    private static boolean isCharType(RelDataType type) {
      SqlTypeName typeName = type.getSqlTypeName();
      return typeName == SqlTypeName.CHAR || typeName == SqlTypeName.VARCHAR;
    }
  }
}
Loading
Loading