Skip to content

Commit

Permalink
[CALCITE-6824] Subquery in join conditions rewrite fails if referenci…
Browse files Browse the repository at this point in the history
…ng a column from the right-hand side table
  • Loading branch information
suibianwanwank authored and asolimando committed Feb 19, 2025
1 parent dd49d9f commit 130c8dc
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -910,16 +910,56 @@ private static void matchJoin(SubQueryRemoveRule rule, RelOptRuleCall call) {
final RelOptUtil.Logic logic =
LogicVisitor.find(RelOptUtil.Logic.TRUE,
ImmutableList.of(join.getCondition()), e);
builder.push(join.getLeft());
builder.push(join.getRight());
final int fieldCount = join.getRowType().getFieldCount();
final Set<CorrelationId> variablesSet =
RelOptUtil.getVariablesUsed(e.rel);
final RexNode target =
rule.apply(e, variablesSet, logic, builder, 2, fieldCount, 0);
final RexShuttle shuttle = new ReplaceSubQueryShuttle(e, target);
builder.join(join.getJoinType(), shuttle.apply(join.getCondition()));
builder.project(fields(builder, join.getRowType().getFieldCount()));

ImmutableBitSet inputSet = RelOptUtil.InputFinder.bits(e.getOperands(), null);
int nFieldsLeft = join.getLeft().getRowType().getFieldCount();
int nFieldsRight = join.getRight().getRowType().getFieldCount();


boolean inputIntersectsLeftSide = inputSet.intersects(ImmutableBitSet.range(0, nFieldsLeft));
boolean inputIntersectsRightSide =
inputSet.intersects(ImmutableBitSet.range(nFieldsLeft, nFieldsLeft + nFieldsRight));
if (inputIntersectsLeftSide && inputIntersectsRightSide) {
// The current existential rewrite needs to make join with one side of the origin join and
// generate a new condition to replace the on clause. But for RexNode whose operands are
// on either side of the join, we can't push them into join. So this rewriting is not
// supported.
return;
}

final Set<CorrelationId> variablesSet = RelOptUtil.getVariablesUsed(e.rel);
if (inputIntersectsLeftSide) {
builder.push(join.getLeft());

final RexNode target =
rule.apply(e, variablesSet, logic, builder, 1, nFieldsLeft, 0);
final RexShuttle shuttle = new ReplaceSubQueryShuttle(e, target);

final RexNode newCond =
shuttle.apply(
RexUtil.shift(join.getCondition(), nFieldsLeft,
builder.fields().size() - nFieldsLeft));
builder.push(join.getRight());
builder.join(join.getJoinType(), newCond);

final int nFields = builder.fields().size();
ImmutableList<RexNode> fields =
builder.fields(ImmutableBitSet.range(0, nFieldsLeft)
.union(ImmutableBitSet.range(nFields - nFieldsRight, nFields)));
builder.project(fields);
} else {
builder.push(join.getLeft());
builder.push(join.getRight());

final int nFields = join.getRowType().getFieldCount();
final RexNode target =
rule.apply(e, variablesSet, logic, builder, 2, nFields, 0);
final RexShuttle shuttle = new ReplaceSubQueryShuttle(e, target);

builder.join(join.getJoinType(), shuttle.apply(join.getCondition()));
builder.project(fields(builder, nFields));
}

call.transformTo(builder.build());
}

Expand Down
39 changes: 39 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -9116,6 +9116,45 @@ public interface Config extends RelRule.Config {
.check();
}

/**
* Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6824">[CALCITE-6824]
* Subquery in join conditions rewrite fails if referencing a column
* from the right-hand side table</a>. */
@Test void testJoinSubQueryRemoveRuleWithNotIn() {
final String sql = "SELECT empno FROM emp JOIN dept on "
+ "emp.deptno not in (SELECT deptno FROM dept)";
sql(sql)
.withRule(CoreRules.JOIN_SUB_QUERY_TO_CORRELATE)
.check();
}

/**
* Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6824">[CALCITE-6824]
* Subquery in join conditions rewrite fails if referencing a column
* from the right-hand side table</a>. */
@Test void testJoinSubQueryRemoveRuleWithQuantifierSome() {
final String sql = "SELECT empno FROM emp JOIN dept on "
+ "emp.deptno >= SOME(SELECT deptno FROM dept)";
sql(sql)
.withRule(CoreRules.JOIN_SUB_QUERY_TO_CORRELATE)
.check();
}

/**
* Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6824">[CALCITE-6824]
* Subquery in join conditions rewrite fails if referencing a column
* from the right-hand side table</a>. */
@Test void testJoinSubQueryRewriteWithBothSidesColumns() {
final String sql = "SELECT empno FROM emp JOIN dept on "
+ "emp.deptno + dept.deptno >= SOME(SELECT deptno FROM dept)";
sql(sql)
.withRule(CoreRules.JOIN_SUB_QUERY_TO_CORRELATE)
.checkUnchanged();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-2295">[CALCITE-2295]
* Correlated SubQuery with Project will generate error plan</a>. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6282,6 +6282,84 @@ LogicalProject(SAL=[$5])
LogicalProject(SAL=[$5], $f9=[=($5, 4)])
LogicalFilter(condition=[AND(=($7, 20), >($5, 1000))])
LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
]]>
</Resource>
</TestCase>
<TestCase name="testJoinSubQueryRemoveRuleWithNotIn">
<Resource name="sql">
<![CDATA[SELECT empno FROM emp JOIN dept on emp.deptno not in (SELECT deptno FROM dept)]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(EMPNO=[$0])
LogicalJoin(condition=[NOT(IN($7, {
LogicalProject(DEPTNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
}))], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalProject(EMPNO=[$0])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[$13], NAME=[$14])
LogicalJoin(condition=[OR(=($9, 0), AND(IS NULL($12), >=($10, $9)))], joinType=[inner])
LogicalJoin(condition=[=($7, $11)], joinType=[left])
LogicalJoin(condition=[true], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(c=[$0], ck=[$0])
LogicalAggregate(group=[{}], c=[COUNT()])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalProject(DEPTNO=[$0], i=[true])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
</TestCase>
<TestCase name="testJoinSubQueryRemoveRuleWithQuantifierSome">
<Resource name="sql">
<![CDATA[SELECT empno FROM emp JOIN dept on emp.deptno >= SOME(SELECT deptno FROM dept)]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(EMPNO=[$0])
LogicalJoin(condition=[>= SOME($7, {
LogicalProject(DEPTNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
})], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalProject(EMPNO=[$0])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[$12], NAME=[$13])
LogicalJoin(condition=[CAST(OR(AND(IS TRUE(>=($7, $9)), <>($10, 0)), AND(>($10, $11), null, <>($10, 0), IS NOT TRUE(>=($7, $9))), AND(>=($7, $9), <>($10, 0), IS NOT TRUE(>=($7, $9)), <=($10, $11)))):BOOLEAN NOT NULL], joinType=[inner])
LogicalJoin(condition=[true], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(m=[$0], c=[$1], d=[$1])
LogicalAggregate(group=[{}], m=[MIN($0)], c=[COUNT()])
LogicalProject(DEPTNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
</TestCase>
<TestCase name="testJoinSubQueryRewriteWithBothSidesColumns">
<Resource name="sql">
<![CDATA[SELECT empno FROM emp JOIN dept on emp.deptno + dept.deptno >= SOME(SELECT deptno FROM dept)]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(EMPNO=[$0])
LogicalJoin(condition=[>= SOME(+($7, $9), {
LogicalProject(DEPTNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
})], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
</TestCase>
Expand Down
72 changes: 72 additions & 0 deletions core/src/test/resources/sql/sub-query.iq
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,78 @@ EnumerableCalc(expr#0..4=[{inputs}], EMPNO=[$t0])
EnumerableTableScan(table=[[scott, DEPT]])
!plan

# [CALCITE-6824] Subquery in join conditions rewrite fails if referencing a column from the right-hand side table
select empno from "scott".emp where (empno not in (select dept.deptno from dept))
in (select deptno = 0 from dept);
+-------+
| EMPNO |
+-------+
+-------+
(0 rows)

!ok
EnumerableCalc(expr#0..3=[{inputs}], EMPNO=[$t0])
EnumerableNestedLoopJoin(condition=[=(IS NULL($2), $3)], joinType=[inner])
EnumerableMergeJoin(condition=[=($0, $1)], joinType=[left])
EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0])
EnumerableTableScan(table=[[scott, EMP]])
EnumerableSort(sort0=[$0], dir0=[ASC])
EnumerableAggregate(group=[{0}], i=[LITERAL_AGG(true)])
EnumerableCalc(expr#0..2=[{inputs}], expr#3=[CAST($t0):SMALLINT NOT NULL], DEPTNO=[$t3])
EnumerableTableScan(table=[[scott, DEPT]])
EnumerableAggregate(group=[{0}])
EnumerableCalc(expr#0..2=[{inputs}], expr#3=[CAST($t0):INTEGER NOT NULL], expr#4=[0], expr#5=[=($t3, $t4)], EXPR$0=[$t5])
EnumerableTableScan(table=[[scott, DEPT]])
!plan

# [CALCITE-6824] Subquery in join conditions rewrite fails if referencing a column from the right-hand side table
SELECT empno FROM emp JOIN dept on emp.deptno <= ALL(SELECT deptno FROM dept) and emp.deptno = dept.deptno;
+-------+
| EMPNO |
+-------+
| 7782 |
| 7839 |
| 7934 |
+-------+
(3 rows)

!ok
EnumerableCalc(expr#0..4=[{inputs}], EMPNO=[$t0])
EnumerableHashJoin(condition=[=($1, $5)], joinType=[semi])
EnumerableNestedLoopJoin(condition=[OR(=($3, 0), AND(<=($1, $2), IS NOT TRUE(OR(>($1, $2), >($3, $4)))))], joinType=[inner])
EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
EnumerableTableScan(table=[[scott, EMP]])
EnumerableCalc(expr#0..1=[{inputs}], proj#0..1=[{exprs}], d=[$t1])
EnumerableAggregate(group=[{}], m=[MIN($0)], c=[COUNT()])
EnumerableTableScan(table=[[scott, DEPT]])
EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
EnumerableTableScan(table=[[scott, DEPT]])
!plan

# [CALCITE-6824] Subquery in join conditions rewrite fails if referencing a column from the right-hand side table
SELECT empno FROM emp JOIN dept on emp.deptno = (SELECT min(deptno) FROM dept) and emp.deptno = dept.deptno;
+-------+
| EMPNO |
+-------+
| 7782 |
| 7839 |
| 7934 |
+-------+
(3 rows)

!ok
EnumerableCalc(expr#0..2=[{inputs}], EMPNO=[$t0])
EnumerableHashJoin(condition=[=($2, $3)], joinType=[semi])
EnumerableCalc(expr#0..2=[{inputs}], EMPNO=[$t1], DEPTNO=[$t2], EXPR$0=[$t0])
EnumerableHashJoin(condition=[=($0, $2)], joinType=[inner])
EnumerableAggregate(group=[{}], EXPR$0=[MIN($0)])
EnumerableTableScan(table=[[scott, DEPT]])
EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
EnumerableTableScan(table=[[scott, EMP]])
EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
EnumerableTableScan(table=[[scott, DEPT]])
!plan

# Correlated NOT IN sub-query in WHERE clause of JOIN
select empno from "scott".emp as e
join "scott".dept as d using (deptno)
Expand Down

0 comments on commit 130c8dc

Please sign in to comment.