Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,15 @@ public UnifiedQueryPlanner(UnifiedQueryContext context) {
*/
public RelNode plan(String query) {
try {
return context.measure(ANALYZE, () -> strategy.plan(query));
return context.measure(
ANALYZE,
() -> {
RelNode plan = strategy.plan(query);
for (var shuttle : context.getLangSpec().postAnalysisRules()) {
plan = plan.accept(shuttle);
}
return plan;
});
} catch (SyntaxCheckException | UnsupportedOperationException e) {
throw e;
} catch (Exception e) {
Expand Down
32 changes: 25 additions & 7 deletions api/src/main/java/org/opensearch/sql/api/spec/LanguageSpec.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
Expand All @@ -17,8 +18,8 @@

/**
* Language specification defining the dialect the engine accepts. Provides parser configuration,
* validator configuration, and composable {@link LanguageExtension}s that contribute operators and
* post-parse rewrite rules.
* validator configuration, and composable {@link LanguageExtension}s that contribute operators,
* post-parse rewrite rules, and post-analysis rewrite rules.
*
* <p>Implementations define a complete language surface — for example, {@link UnifiedSqlSpec}
* provides ANSI and extended SQL modes. A future PPL spec would implement this same interface once
Expand All @@ -27,8 +28,9 @@
public interface LanguageSpec {

/**
* A composable language extension that contributes operators and post-parse rewrite rules. All
* methods have defaults so extensions only override what they need.
* A composable language extension that contributes operators, post-parse rewrite rules, and
* post-analysis rewrite rules. All methods have defaults so extensions only override what they
* need.
*/
interface LanguageExtension {

Expand All @@ -47,6 +49,14 @@ default SqlOperatorTable operators() {
default List<SqlVisitor<SqlNode>> postParseRules() {
return List.of();
}

/**
* RelNode rewrite rules applied after analysis and before execution. Each rule transforms the
* logical plan tree. Rules within a single extension are applied in list order.
*/
default List<RelShuttle> postAnalysisRules() {
return List.of();
}
}

/**
Expand All @@ -62,9 +72,9 @@ default List<SqlVisitor<SqlNode>> postParseRules() {
SqlValidator.Config validatorConfig();

/**
* Language extensions registered with this spec. Each extension contributes operators and
* post-parse rewrite rules that are composed by {@link #operatorTable()} and {@link
* #postParseRules()}.
* Language extensions registered with this spec. Each extension contributes operators, post-parse
* rewrite rules, and post-analysis rewrite rules composed by {@link #operatorTable()}, {@link
* #postParseRules()}, and {@link #postAnalysisRules()}.
*/
List<LanguageExtension> extensions();

Expand All @@ -86,4 +96,12 @@ default SqlOperatorTable operatorTable() {
default List<SqlVisitor<SqlNode>> postParseRules() {
return extensions().stream().flatMap(ext -> ext.postParseRules().stream()).toList();
}

/**
* All post-analysis RelNode rewrite rules from registered extensions, flattened in registration
* order. Applied to the logical plan after analysis and before execution.
*/
default List<RelShuttle> postAnalysisRules() {
return extensions().stream().flatMap(ext -> ext.postAnalysisRules().stream()).toList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import lombok.NoArgsConstructor;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.validate.SqlValidator;
import org.opensearch.sql.api.spec.datetime.DatetimeExtension;

/**
* PPL language specification.
Expand Down Expand Up @@ -37,6 +38,6 @@ public SqlValidator.Config validatorConfig() {

@Override
public List<LanguageExtension> extensions() {
return List.of();
return List.of(new DatetimeExtension());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.api.spec.datetime;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.type.SqlTypeName;
import org.opensearch.sql.api.spec.LanguageSpec.LanguageExtension;
import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT;

/** Datetime language extension that normalizes UDT types and casts output for wire-format. */
public class DatetimeExtension implements LanguageExtension {

@Override
public List<RelShuttle> postAnalysisRules() {
return List.of(DatetimeUdtNormalizeRule.INSTANCE, DatetimeOutputCastRule.INSTANCE);
}

/** Maps datetime UDT types to their standard Calcite equivalents. */
@Getter
@RequiredArgsConstructor
enum UdtMapping {
DATE(ExprUDT.EXPR_DATE, SqlTypeName.DATE),
TIME(ExprUDT.EXPR_TIME, SqlTypeName.TIME),
TIMESTAMP(ExprUDT.EXPR_TIMESTAMP, SqlTypeName.TIMESTAMP);

private final ExprUDT udtType;
private final SqlTypeName stdType;

/** Matches a UDT RelDataType to its mapping, or empty if not a datetime UDT. */
static Optional<UdtMapping> fromUdtType(RelDataType type) {
if (!(type instanceof AbstractExprRelDataType<?> e)) {
return Optional.empty();
}
ExprUDT udt = e.getUdt();
return Arrays.stream(values()).filter(u -> u.udtType == udt).findFirst();
}

/** Returns true if the given SqlTypeName is a standard datetime type. */
static boolean isDatetimeType(SqlTypeName typeName) {
return Arrays.stream(values()).anyMatch(u -> u.stdType == typeName);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.api.spec.datetime;

import static org.opensearch.sql.api.spec.datetime.DatetimeExtension.UdtMapping.isDatetimeType;

import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.type.SqlTypeName;

/** Wraps the root output with CAST(datetime → VARCHAR) for PPL wire-format compatibility. */
class DatetimeOutputCastRule extends RelHomogeneousShuttle {

static final DatetimeOutputCastRule INSTANCE = new DatetimeOutputCastRule();

@Override
public RelNode visit(RelNode other) {
List<RelDataTypeField> fields = other.getRowType().getFieldList();
if (fields.stream().noneMatch(f -> isDatetimeType(f.getType().getSqlTypeName()))) {
return other;
}

RexBuilder rexBuilder = other.getCluster().getRexBuilder();
List<RexNode> projects = new ArrayList<>(fields.size());
List<String> names = new ArrayList<>(fields.size());

for (RelDataTypeField field : fields) {
RexNode ref = rexBuilder.makeInputRef(other, field.getIndex());
if (isDatetimeType(field.getType().getSqlTypeName())) {
projects.add(castToVarchar(rexBuilder, ref, field.getType()));
} else {
projects.add(ref);
}
names.add(field.getName());
}
return LogicalProject.create(other, List.of(), projects, names);
}

private static RexNode castToVarchar(RexBuilder rexBuilder, RexNode expr, RelDataType fieldType) {
RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
RelDataType varcharType =
typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.VARCHAR), fieldType.isNullable());
return rexBuilder.makeCast(varcharType, expr);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.api.spec.datetime;

import java.util.Optional;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.opensearch.sql.api.spec.datetime.DatetimeExtension.UdtMapping;

/**
* Temporary patch that rewrites datetime UDT return types on RexCall nodes to standard Calcite
* types.
*/
class DatetimeUdtNormalizeRule extends RelHomogeneousShuttle {

static final DatetimeUdtNormalizeRule INSTANCE = new DatetimeUdtNormalizeRule();

@Override
public RelNode visit(RelNode other) {
RelNode visited = super.visit(other);
RexBuilder rexBuilder = visited.getCluster().getRexBuilder();
RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
return visited.accept(
new RexShuttle() {
@Override
public RexNode visitCall(RexCall call) {
call = (RexCall) super.visitCall(call);
Optional<UdtMapping> mapping = UdtMapping.fromUdtType(call.getType());
if (mapping.isEmpty()) {
return call;
}

RelDataType stdType =
typeFactory.createTypeWithNullability(
typeFactory.createSqlType(mapping.get().getStdType()),
call.getType().isNullable());
return call.clone(stdType, call.getOperands());
}
});
}
}
Loading
Loading