Skip to content

fix(isthmus): handle Subqueries/set predicates with field references outside of the subquery #383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions isthmus/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,18 @@ plugins {
id("com.diffplug.spotless") version "6.19.0"
id("com.github.johnrengelman.shadow") version "8.1.1"
id("com.google.protobuf") version "0.9.4"
id("com.adarshr.test-logger") version "4.0.0"
signing
}

// Useful test logger when debugging: it can echo each test's stdout/stderr to the console,
// which saves time launching the HTML test reports.
// NOTE(review): per the test-logger plugin docs, showStandardStreams looks like the master
// switch, so while it is false the two per-outcome flags below presumably have no visible
// effect. Confirm, and flip it to true when console output is actually wanted.
testlogger {
showStandardStreams = false
showPassedStandardStreams = true
showFailedStandardStreams = true
}

publishing {
publications {
create<MavenPublication>("maven-publish") {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package io.substrait.isthmus;

import static io.substrait.isthmus.SqlToSubstrait.EXTENSION_COLLECTION;
import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;

import com.google.common.collect.ImmutableList;
import io.substrait.expression.Expression;
Expand Down Expand Up @@ -44,13 +44,15 @@
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexSlot;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Holder;

/**
* RelVisitor to convert Substrait Rel plan to Calcite RelNode plan. Unsupported Rel node will call
Expand Down Expand Up @@ -139,8 +141,12 @@ public static RelNode convert(
/**
 * Converts a Substrait {@link Filter} into a Calcite filter node.
 *
 * <p>A correlation variable is created for the filter's input and handed to the expression
 * converter before the condition is converted, so the converter has it available while visiting
 * the condition. The resulting filter is built with that variable's id in its variables set.
 *
 * @param filter the Substrait filter relation to convert
 * @return the Calcite node for the filter, with the filter's remap applied
 */
@Override
public RelNode visit(Filter filter) throws RuntimeException {
  RelNode input = filter.getInput().accept(this);

  // The holder starts empty; RelBuilder#variable fills it in once the input has been pushed.
  final Holder<RexCorrelVariable> v = Holder.empty();
  expressionRexConverter.addCorrelVariable(v);

  RelBuilder r1 = relBuilder.push(input).variable(v::set);
  RexNode filterCondition = filter.getCondition().accept(expressionRexConverter);

  // Build from the builder that already holds the input; pushing the input again (as the
  // previous single-line form did) would leave a duplicate on the builder stack.
  RelNode node = r1.filter(List.of(v.get().id), filterCondition).build();
  return applyRemap(node, filter.getRemap());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import io.substrait.type.Type;
import io.substrait.util.DecimalUtil;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
Expand All @@ -34,6 +36,7 @@
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldCollation;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
Expand All @@ -46,6 +49,7 @@
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.util.Holder;
import org.apache.calcite.util.TimeString;
import org.apache.calcite.util.TimestampString;

Expand Down Expand Up @@ -468,7 +472,7 @@ public RexNode visit(Expression.InPredicate expr) throws RuntimeException {
return RexSubQuery.in(rel, ImmutableList.copyOf(needles));
}

static class ToRexWindowBound
static final class ToRexWindowBound
implements WindowBound.WindowBoundVisitor<RexWindowBound, RuntimeException> {

static RexWindowBound lowerBound(RexBuilder rexBuilder, WindowBound bound) {
Expand Down Expand Up @@ -538,22 +542,39 @@ public RexNode visit(Expression.Cast expr) throws RuntimeException {
@Override
public RexNode visit(FieldReference expr) throws RuntimeException {
if (expr.isSimpleRootReference()) {
Optional<Integer> outerref = expr.outerReferenceStepsOut();
var segment = expr.segments().get(0);
if (outerref.isPresent()) {
if (segment instanceof FieldReference.StructField) {
FieldReference.StructField f = (FieldReference.StructField) segment;
var node = referenceRelList.get(outerref.get() - 1).get();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to track or handle the fact that there might be multiple Filters that can add correlation variables? We only ever add to this list.

How would we know which ones came from which input?

This is one of the issues I had in mind when I mentioned
https://github.com/substrait-io/substrait-java/pull/383/files#r2038429378

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, the Calcite docs don't really help with the scope of these variables. There is some concept of a 'namespace', e.g. from this error:

All correlation variables should resolve to the same namespace. Prev ns=org.apache.calcite.sql.validate.IdentifierNamespace@d36c1c3, new ns=org.apache.calcite.sql.validate.IdentifierNamespace@96abc7

which came from

select
    c1.c_name,
    o1.o_orderstatus,
    o1.o_totalprice
from
    customer c1,
    orders o1
where
    o1.o_custkey = c1.c_custkey
    and o1.o_totalprice > (
        select
            avg(o_totalprice)
        from
            orders o2, customer c2
        where
            o2.o_totalprice < c1.c_acctbal
            and o2.o_totalprice > (
                select
                    avg(c3.c_acctbal)
                from
                    customer c3
                where
                    c3.c_custkey = o2.o_custkey
                    and c3.c_address = o1.o_clerk
            )
    );

Change the last line to c3.c_address = o2.o_clerk and it's OK.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implication is that each relation and its immediate subexpressions share the same namespace.


RexInputRef rexInputRef;
if (segment instanceof FieldReference.StructField f) {
rexInputRef =
new RexInputRef(f.offset(), typeConverter.toCalcite(typeFactory, expr.getType()));
return rexBuilder.makeFieldAccess(node, f.offset());
}
} else {
throw new IllegalArgumentException("Unhandled type: " + segment);
RexInputRef rexInputRef;
if (segment instanceof FieldReference.StructField f) {
rexInputRef =
new RexInputRef(f.offset(), typeConverter.toCalcite(typeFactory, expr.getType()));
} else {
throw new IllegalArgumentException("Unhandled type: " + segment);
}
return rexInputRef;
}

return rexInputRef;
}

return visitFallback(expr);
}

// Holders for correlation variables, kept in the order they were passed to addCorrelVariable.
// In the code shown here entries are only ever appended, never removed.
protected List<Holder<RexCorrelVariable>> referenceRelList = new ArrayList<>();

/**
 * Registers a correlation-variable holder by appending it to {@code referenceRelList}.
 *
 * @param correlVariable the holder to register; it may still be empty when registered
 */
public void addCorrelVariable(Holder<RexCorrelVariable> correlVariable) {
  referenceRelList.add(correlVariable);
}

/**
 * Looks up a previously registered correlation-variable holder.
 *
 * @param i zero-based position in registration order
 * @return the holder stored at that position
 */
public Holder<RexCorrelVariable> getOuterRef(int i) {
  final Holder<RexCorrelVariable> registered = referenceRelList.get(i);
  return registered;
}

@Override
public RexNode visitFallback(Expression expr) {
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assumptions.assumeFalse;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

import io.substrait.plan.ProtoPlanConverter;
import io.substrait.proto.Plan;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.calcite.rel.RelNode;
import org.junit.jupiter.api.MethodOrderer.OrderAnnotation;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;
import org.junit.jupiter.api.TestMethodOrder;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

/**
* Additional queries based around the style and schema of the TPC-H set, validating that the
* conversions can operate without exceptions.
*/
@TestMethodOrder(OrderAnnotation.class)
@TestInstance(Lifecycle.PER_CLASS)
public class TestExtendedTpchQuery extends PlanTestBase {

private Map<Integer, Plan> allPlans = new HashMap<>();

// Keep list of the known test failures
// The `fromSubstrait` also assumes the to substrait worked as well
public static final List<Integer> toSubstraitKnownFails = List.of();
public static final List<Integer> fromSubstraitKnownFails = List.of();

/**
 * Converts one extended TPC-H style query to a Substrait proto plan and records it in {@code
 * allPlans} under its query number. Skipped for queries listed in {@code toSubstraitKnownFails}.
 */
@ParameterizedTest
@Order(1)
@ValueSource(ints = {1})
public void extendedTpchToSubstrait(int query) throws Exception {
  assumeFalse(toSubstraitKnownFails.contains(query));

  SqlToSubstrait converter = new SqlToSubstrait();

  // schema.sql is split on ';'; fragments that are blank after trimming are dropped.
  List<String> createStatements = new java.util.ArrayList<>();
  for (String fragment : asString("tpch/schema.sql").split(";")) {
    if (!fragment.trim().isBlank()) {
      createStatements.add(fragment);
    }
  }

  String sql = asString(String.format("tpch/extended/%02d.sql", query));
  Plan converted = converter.execute(sql, createStatements);

  allPlans.put(query, converted);
}

@ParameterizedTest
@Order(2)
@ValueSource(ints = {1})
public void extendedTpchFromSubstrait(int query) throws Exception {
assumeFalse(fromSubstraitKnownFails.contains(query));
assumeTrue(allPlans.containsKey(query));

Plan possible = allPlans.get(query);

io.substrait.plan.Plan plan = new ProtoPlanConverter().from(possible);
SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory);
RelNode relRoot = substraitToCalcite.convert(plan.getRoots().get(0)).project(true);
System.out.println(SubstraitToSql.toSql(relRoot));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: avoid printing during test. I suggest assertNotNull instead

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also changed the other existing tests where this occurred.

}
}
88 changes: 88 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/TestTpcdsQuery.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assumptions.assumeFalse;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

import io.substrait.plan.ProtoPlanConverter;
import io.substrait.proto.Plan;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.apache.calcite.adapter.tpcds.TpcdsSchema;
import org.apache.calcite.rel.RelNode;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.MethodOrderer.OrderAnnotation;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;
import org.junit.jupiter.api.TestMethodOrder;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

/**
* Updated TPC-DS test to convert SQL to Substrait and replay those plans back to SQL, validating that
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume you mean TPC-DS here?

* the conversions can operate without exceptions
*/
@TestMethodOrder(OrderAnnotation.class)
@TestInstance(Lifecycle.PER_CLASS)
public class TestTpcdsQuery extends PlanTestBase {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add these tests separately from your PR if possible, as it would be reasonable to have as many of these that already work running while we figure out the outer reference issues.


private List<Optional<Plan>> allPlans;

/**
 * Pre-fills {@code allPlans} with 100 empty slots (indices 0-99) so that a plan can later be
 * stored at the index equal to its query number.
 */
@BeforeAll
public void setup() {
  allPlans = new ArrayList<>(java.util.Collections.nCopies(100, Optional.<Plan>empty()));
}

// Keep list of the known test failures
// The `fromSubstrait` also assumes the to substrait worked as well
public static final List<Integer> toSubstraitKnownFails =
List.of(5, 9, 12, 20, 27, 36, 47, 51, 53, 57, 63, 66, 70, 80, 84, 86, 89, 91, 98);
public static final List<Integer> fromSubstraitKnownFails = List.of(1, 8, 30, 49, 67, 81);

/**
 * Converts one TPC-DS query to a Substrait proto plan and stores it in {@code allPlans} at the
 * index equal to the query number. Skipped for queries listed in {@code toSubstraitKnownFails}.
 */
@ParameterizedTest
@Order(1)
@ValueSource(
    ints = {
      1, 3, 4, 6, 7, 8, 10, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 25, 26, 28, 29, 30,
      31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49, 50, 52, 54, 55, 56, 58,
      59, 60, 61, 62, 64, 65, 67, 68, 69, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 85, 87,
      88, 90, 92, 93, 94, 95, 96, 97, 99, 2, 5, 9, 12, 20, 27, 36, 47, 51, 53, 57, 63, 66, 70, 80,
      84, 86, 89, 91, 98,
    })
public void tpcdsSuccess(int query) throws Exception {
  assumeFalse(toSubstraitKnownFails.contains(query));

  SqlToSubstrait converter = new SqlToSubstrait();
  TpcdsSchema tpcds = new TpcdsSchema(1.0);
  String queryText = asString(String.format("tpcds/queries/%02d.sql", query));

  Optional<Plan> converted = Optional.of(converter.execute(queryText, "tpcds", tpcds));
  allPlans.set(query, converted);
}

/**
 * Replays the proto plan recorded by {@code tpcdsSuccess} for the same query number: proto plan
 * to Substrait plan, to a Calcite {@link RelNode}, to SQL text.
 *
 * <p>Skipped when the query is listed in {@code fromSubstraitKnownFails} or when no plan was
 * recorded for it.
 */
@ParameterizedTest
@Order(2)
@ValueSource(
    ints = {
      1, 3, 4, 6, 7, 8, 10, 11, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 25, 26, 28, 29, 30,
      31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49, 50, 52, 54, 55, 56, 58,
      59, 60, 61, 62, 64, 65, 67, 68, 69, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 83, 85, 87,
      88, 90, 92, 93, 94, 95, 96, 97, 99, 2, 5, 9, 12, 20, 27, 36, 47, 51, 53, 57, 63, 66, 70, 80,
      84, 86, 89, 91, 98,
    })
public void tpcdsFromSubstrait(int query) throws Exception {

  assumeFalse(fromSubstraitKnownFails.contains(query));

  Optional<Plan> possible = allPlans.get(query);
  assumeTrue(possible.isPresent());

  io.substrait.plan.Plan plan = new ProtoPlanConverter().from(possible.get());
  SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory);
  RelNode relRoot = substraitToCalcite.convert(plan.getRoots().get(0)).project(true);

  // Assert on the generated SQL instead of printing it: the test then checks a result and
  // keeps the build output quiet.
  org.junit.jupiter.api.Assertions.assertNotNull(SubstraitToSql.toSql(relRoot));
}
}
69 changes: 69 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/TestTpchQuery.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assumptions.assumeFalse;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

import io.substrait.plan.ProtoPlanConverter;
import io.substrait.proto.Plan;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.calcite.rel.RelNode;
import org.junit.jupiter.api.MethodOrderer.OrderAnnotation;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;
import org.junit.jupiter.api.TestMethodOrder;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

/**
* Updated TPC-H test to convert SQL to Substrait and replay those plans back to SQL, validating that
* the conversions can operate without exceptions.
*/
@TestMethodOrder(OrderAnnotation.class)
@TestInstance(Lifecycle.PER_CLASS)
public class TestTpchQuery extends PlanTestBase {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add these tests separately from your PR if possible, as it would be reasonable to have as many of these that already work running while we figure out the outer reference issues.


private Map<Integer, Plan> allPlans = new HashMap<>();

// Keep list of the known test failures
// The `fromSubstrait` also assumes the to substrait worked as well
public static final List<Integer> toSubstraitKnownFails = List.of(22);
public static final List<Integer> fromSubstraitKnownFails = List.of(7, 8, 9);

/**
 * Converts one TPC-H query to a Substrait proto plan and records it in {@code allPlans} under
 * its query number. Skipped for queries listed in {@code toSubstraitKnownFails}.
 */
@ParameterizedTest
@Order(1)
@ValueSource(
    ints = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22})
public void tpchToSubstrait(int query) throws Exception {
  assumeFalse(toSubstraitKnownFails.contains(query));

  SqlToSubstrait converter = new SqlToSubstrait();

  // schema.sql is split on ';'; fragments that are blank after trimming are dropped.
  List<String> createStatements = new java.util.ArrayList<>();
  for (String fragment : asString("tpch/schema.sql").split(";")) {
    if (!fragment.trim().isBlank()) {
      createStatements.add(fragment);
    }
  }

  String sql = asString(String.format("tpch/queries/%02d.sql", query));
  Plan converted = converter.execute(sql, createStatements);

  allPlans.put(query, converted);
}

/**
 * Replays the proto plan recorded by {@code tpchToSubstrait} for the same query number: proto
 * plan to Substrait plan, to a Calcite {@link RelNode}, to SQL text.
 *
 * <p>Skipped when the query is listed in {@code fromSubstraitKnownFails} or when no plan was
 * recorded for it.
 */
@ParameterizedTest
@Order(2)
@ValueSource(
    ints = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22})
public void tpchFromSubstrait(int query) throws Exception {
  assumeFalse(fromSubstraitKnownFails.contains(query));
  assumeTrue(allPlans.containsKey(query));

  Plan possible = allPlans.get(query);

  io.substrait.plan.Plan plan = new ProtoPlanConverter().from(possible);
  SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory);
  RelNode relRoot = substraitToCalcite.convert(plan.getRoots().get(0)).project(true);

  // Assert on the generated SQL instead of printing it: the test then checks a result and
  // keeps the build output quiet.
  org.junit.jupiter.api.Assertions.assertNotNull(SubstraitToSql.toSql(relRoot));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import com.google.protobuf.util.JsonFormat;
import java.util.Arrays;
import org.junit.jupiter.api.MethodOrderer.OrderAnnotation;
import org.junit.jupiter.api.TestMethodOrder;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

@TestMethodOrder(OrderAnnotation.class)
public class TpchQueryNoValidation extends PlanTestBase {

@ParameterizedTest
Expand All @@ -18,7 +21,7 @@ public void tpch(int query) throws Exception {
Arrays.stream(values)
.filter(t -> !t.trim().isBlank())
.collect(java.util.stream.Collectors.toList());
var plan = s.execute(asString(String.format("tpch/queries/%02d.sql", query)), creates);
System.out.println(JsonFormat.printer().print(plan));
var protoPlan = s.execute(asString(String.format("tpch/queries/%02d.sql", query)), creates);
System.out.println(JsonFormat.printer().print(protoPlan));
}
}
Loading
Loading