Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support select with statement sql bind and add bind test case #34141

Merged
merged 9 commits into from
Dec 24, 2024
Merged
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
1. SQL Binder: Support create index statement sql bind - [#34112](https://github.com/apache/shardingsphere/pull/34112)
1. SQL Parser: Support MySQL update with statement parse - [#34126](https://github.com/apache/shardingsphere/pull/34126)
1. SQL Binder: Remove TablesContext#findTableNames method and implement select order by, group by bind logic - [#34123](https://github.com/apache/shardingsphere/pull/34123)
1. SQL Binder: Support select with statement sql bind and add bind test case - [#34141](https://github.com/apache/shardingsphere/pull/34141)

### Bug Fixes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

package org.apache.shardingsphere.encrypt.checker.sql;

import org.apache.shardingsphere.encrypt.checker.sql.projection.EncryptInsertSelectProjectionSupportedChecker;
import org.apache.shardingsphere.encrypt.checker.sql.projection.EncryptSelectProjectionSupportedChecker;
import org.apache.shardingsphere.encrypt.checker.sql.orderby.EncryptOrderByItemSupportedChecker;
import org.apache.shardingsphere.encrypt.checker.sql.predicate.EncryptPredicateColumnSupportedChecker;
import org.apache.shardingsphere.encrypt.checker.sql.projection.EncryptInsertSelectProjectionSupportedChecker;
import org.apache.shardingsphere.encrypt.checker.sql.projection.EncryptSelectProjectionSupportedChecker;
import org.apache.shardingsphere.encrypt.constant.EncryptOrder;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.infra.checker.SupportedSQLChecker;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.type.SimpleTableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.infra.binder.engine.util.SubqueryTableBindUtils;
import org.apache.shardingsphere.infra.database.core.metadata.database.DialectDatabaseMetaData;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
Expand Down Expand Up @@ -150,6 +151,12 @@ private static SimpleTableSegmentBinderContext createSimpleTableBinderContext(fi
if (binderContext.getSqlStatement() instanceof CreateTableStatement) {
return new SimpleTableSegmentBinderContext(createProjectionSegments((CreateTableStatement) binderContext.getSqlStatement(), databaseName, schemaName, tableName));
}
CaseInsensitiveString caseInsensitiveTableName = new CaseInsensitiveString(tableName.getValue());
if (binderContext.getExternalTableBinderContexts().containsKey(caseInsensitiveTableName)) {
TableSegmentBinderContext tableSegmentBinderContext = binderContext.getExternalTableBinderContexts().get(caseInsensitiveTableName).iterator().next();
return new SimpleTableSegmentBinderContext(
SubqueryTableBindUtils.createSubqueryProjections(tableSegmentBinderContext.getProjectionSegments(), tableName, binderContext.getSqlStatement().getDatabaseType()));
}
return new SimpleTableSegmentBinderContext(Collections.emptyList());
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.infra.binder.engine.segment.with;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.collect.LinkedHashMultimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.type.SimpleTableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.from.type.SubqueryTableSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.complex.CommonTableExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ColumnProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SubqueryTableSegment;

import java.util.stream.Collectors;

/**
* Common table expression segment binder.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class CommonTableExpressionSegmentBinder {

/**
* Bind common table expression segment.
*
* @param segment common table expression segment
* @param binderContext SQL statement binder context
* @param recursive recursive
* @return bound common table expression segment
*/
public static CommonTableExpressionSegment bind(final CommonTableExpressionSegment segment, final SQLStatementBinderContext binderContext, final boolean recursive) {
if (recursive && segment.getAliasName().isPresent()) {
binderContext.getExternalTableBinderContexts().put(new CaseInsensitiveString(segment.getAliasName().get()),
new SimpleTableSegmentBinderContext(segment.getColumns().stream().map(ColumnProjectionSegment::new).collect(Collectors.toList())));
}
SubqueryTableSegment subqueryTableSegment = new SubqueryTableSegment(segment.getStartIndex(), segment.getStopIndex(), segment.getSubquery());
subqueryTableSegment.setAlias(segment.getAliasSegment());
SubqueryTableSegment boundSubquerySegment =
SubqueryTableSegmentBinder.bind(subqueryTableSegment, binderContext, LinkedHashMultimap.create(), binderContext.getExternalTableBinderContexts());
CommonTableExpressionSegment result = new CommonTableExpressionSegment(
segment.getStartIndex(), segment.getStopIndex(), boundSubquerySegment.getAliasSegment().orElse(null), boundSubquerySegment.getSubquery());
// TODO bind with columns
result.getColumns().addAll(segment.getColumns());
return result;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.infra.binder.engine.segment.with;

import com.cedarsoftware.util.CaseInsensitiveMap;
import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.base.Strings;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.type.SimpleTableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.complex.CommonTableExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ColumnProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ShorthandProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.WithSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.TableSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

import java.util.Collection;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;

/**
* With segment binder.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class WithSegmentBinder {

/**
* Bind with segment.
*
* @param segment with segment
* @param binderContext SQL statement binder context
* @param externalTableBinderContexts external table binder contexts
* @return bound with segment
*/
public static WithSegment bind(final WithSegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> externalTableBinderContexts) {
Collection<CommonTableExpressionSegment> boundCommonTableExpressions = new LinkedList<>();
for (CommonTableExpressionSegment each : segment.getCommonTableExpressions()) {
CommonTableExpressionSegment boundCommonTableExpression = CommonTableExpressionSegmentBinder.bind(each, binderContext, segment.isRecursive());
boundCommonTableExpressions.add(boundCommonTableExpression);
if (segment.isRecursive() && each.getAliasName().isPresent()) {
externalTableBinderContexts.removeAll(new CaseInsensitiveString(each.getAliasName().get()));
}
bindWithColumns(each.getColumns(), boundCommonTableExpression);
each.getAliasName().ifPresent(optional -> externalTableBinderContexts.put(new CaseInsensitiveString(optional), createWithTableBinderContext(boundCommonTableExpression)));
}
return new WithSegment(segment.getStartIndex(), segment.getStopIndex(), boundCommonTableExpressions);
}

private static SimpleTableSegmentBinderContext createWithTableBinderContext(final CommonTableExpressionSegment commonTableExpressionSegment) {
return new SimpleTableSegmentBinderContext(commonTableExpressionSegment.getSubquery().getSelect().getProjections().getProjections());
}

private static void bindWithColumns(final Collection<ColumnSegment> columns, final CommonTableExpressionSegment boundCommonTableExpression) {
if (columns.isEmpty()) {
return;
}
Map<String, ColumnProjectionSegment> columnProjections = extractWithSubqueryColumnProjections(boundCommonTableExpression);
columns.forEach(each -> {
ColumnProjectionSegment projectionSegment = columnProjections.get(each.getIdentifier().getValue());
if (null != projectionSegment) {
each.setColumnBoundInfo(createColumnSegmentBoundInfo(each, projectionSegment.getColumn()));
}
});
}

private static Map<String, ColumnProjectionSegment> extractWithSubqueryColumnProjections(final CommonTableExpressionSegment boundCommonTableExpression) {
Map<String, ColumnProjectionSegment> result = new CaseInsensitiveMap<>();
Collection<ProjectionSegment> projections = boundCommonTableExpression.getSubquery().getSelect().getProjections().getProjections();
projections.forEach(each -> extractWithSubqueryColumnProjections(each, result));
return result;
}

private static void extractWithSubqueryColumnProjections(final ProjectionSegment projectionSegment, final Map<String, ColumnProjectionSegment> result) {
if (projectionSegment instanceof ColumnProjectionSegment) {
result.put(getColumnName((ColumnProjectionSegment) projectionSegment), (ColumnProjectionSegment) projectionSegment);
}
if (projectionSegment instanceof ShorthandProjectionSegment) {
((ShorthandProjectionSegment) projectionSegment).getActualProjectionSegments().forEach(eachProjection -> {
if (eachProjection instanceof ColumnProjectionSegment) {
result.put(getColumnName((ColumnProjectionSegment) eachProjection), (ColumnProjectionSegment) eachProjection);
}
});
}
}

private static String getColumnName(final ColumnProjectionSegment columnProjection) {
return columnProjection.getAliasName().orElse(columnProjection.getColumn().getIdentifier().getValue());
}

private static ColumnSegmentBoundInfo createColumnSegmentBoundInfo(final ColumnSegment segment, final ColumnSegment inputColumnSegment) {
IdentifierValue originalDatabase = null == inputColumnSegment ? null : inputColumnSegment.getColumnBoundInfo().getOriginalDatabase();
IdentifierValue originalSchema = null == inputColumnSegment ? null : inputColumnSegment.getColumnBoundInfo().getOriginalSchema();
IdentifierValue segmentOriginalTable = segment.getColumnBoundInfo().getOriginalTable();
IdentifierValue originalTable = Strings.isNullOrEmpty(segmentOriginalTable.getValue())
? Optional.ofNullable(inputColumnSegment).map(optional -> optional.getColumnBoundInfo().getOriginalTable()).orElse(segmentOriginalTable)
: segmentOriginalTable;
IdentifierValue segmentOriginalColumn = segment.getColumnBoundInfo().getOriginalColumn();
IdentifierValue originalColumn = Optional.ofNullable(inputColumnSegment).map(optional -> optional.getColumnBoundInfo().getOriginalColumn()).orElse(segmentOriginalColumn);
return new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(originalDatabase, originalSchema), originalTable, originalColumn);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.shardingsphere.infra.binder.engine.segment.predicate.HavingSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.predicate.WhereSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.projection.ProjectionsSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.with.WithSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.infra.binder.engine.util.SubqueryTableBindUtils;
Expand Down Expand Up @@ -59,6 +60,7 @@ public SelectStatementBinder() {
public SelectStatement bind(final SelectStatement sqlStatement, final SQLStatementBinderContext binderContext) {
SelectStatement result = copy(sqlStatement);
Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts = LinkedHashMultimap.create();
sqlStatement.getWithSegment().ifPresent(optional -> result.setWithSegment(WithSegmentBinder.bind(optional, binderContext, binderContext.getExternalTableBinderContexts())));
Optional<TableSegment> boundTableSegment = sqlStatement.getFrom().map(optional -> TableSegmentBinder.bind(optional, binderContext, tableBinderContexts, outerTableBinderContexts));
boundTableSegment.ifPresent(result::setFrom);
result.setProjections(ProjectionsSegmentBinder.bind(sqlStatement.getProjections(), binderContext, boundTableSegment.orElse(null), tableBinderContexts, outerTableBinderContexts));
Expand All @@ -71,7 +73,6 @@ public SelectStatement bind(final SelectStatement sqlStatement, final SQLStateme
sqlStatement.getOrderBy().ifPresent(optional -> result.setOrderBy(
OrderBySegmentBinder.bind(optional, binderContext, currentTableBinderContexts, tableBinderContexts, outerTableBinderContexts)));
sqlStatement.getHaving().ifPresent(optional -> result.setHaving(HavingSegmentBinder.bind(optional, binderContext, currentTableBinderContexts, outerTableBinderContexts)));
// TODO support other segment bind in select statement
return result;
}

Expand All @@ -90,7 +91,6 @@ private SelectStatement copy(final SelectStatement sqlStatement) {
sqlStatement.getWindow().ifPresent(result::setWindow);
sqlStatement.getModelSegment().ifPresent(result::setModelSegment);
sqlStatement.getSubqueryType().ifPresent(result::setSubqueryType);
sqlStatement.getWithSegment().ifPresent(result::setWithSegment);
result.addParameterMarkerSegments(sqlStatement.getParameterMarkerSegments());
result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
result.getVariableNames().addAll(sqlStatement.getVariableNames());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.shardingsphere.infra.binder.engine.segment.from.TableSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.predicate.WhereSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.with.WithSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.UpdateStatement;
Expand All @@ -38,6 +39,7 @@ public final class UpdateStatementBinder implements SQLStatementBinder<UpdateSta
public UpdateStatement bind(final UpdateStatement sqlStatement, final SQLStatementBinderContext binderContext) {
UpdateStatement result = copy(sqlStatement);
Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts = LinkedHashMultimap.create();
sqlStatement.getWithSegment().ifPresent(optional -> result.setWithSegment(WithSegmentBinder.bind(optional, binderContext, binderContext.getExternalTableBinderContexts())));
result.setTable(TableSegmentBinder.bind(sqlStatement.getTable(), binderContext, tableBinderContexts, LinkedHashMultimap.create()));
sqlStatement.getFrom().ifPresent(optional -> result.setFrom(TableSegmentBinder.bind(optional, binderContext, tableBinderContexts, LinkedHashMultimap.create())));
sqlStatement.getAssignmentSegment().ifPresent(optional -> result.setSetAssignment(AssignmentSegmentBinder.bind(optional, binderContext, tableBinderContexts, LinkedHashMultimap.create())));
Expand All @@ -50,7 +52,6 @@ private UpdateStatement copy(final UpdateStatement sqlStatement) {
UpdateStatement result = sqlStatement.getClass().getDeclaredConstructor().newInstance();
sqlStatement.getOrderBy().ifPresent(result::setOrderBy);
sqlStatement.getLimit().ifPresent(result::setLimit);
sqlStatement.getWithSegment().ifPresent(result::setWithSegment);
result.addParameterMarkerSegments(sqlStatement.getParameterMarkerSegments());
result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
return result;
Expand Down
Loading
Loading