From df079568ec38e7bad6097afab9690c279a635749 Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Wed, 3 Dec 2025 15:44:50 +0100 Subject: [PATCH] style: add final modifier where possible Signed-off-by: Niels Pardon --- .../src/main/resources/substrait-pmd.xml | 3 + .../io/substrait/dsl/SubstraitBuilder.java | 484 +++++++++--------- .../expression/AbstractExpressionVisitor.java | 80 +-- .../java/io/substrait/expression/EnumArg.java | 9 +- .../io/substrait/expression/Expression.java | 126 ++--- .../expression/ExpressionCreator.java | 243 +++++---- .../substrait/expression/FieldReference.java | 87 ++-- .../io/substrait/expression/FunctionArg.java | 33 +- .../io/substrait/expression/WindowBound.java | 12 +- .../proto/ExpressionProtoConverter.java | 211 ++++---- .../proto/ProtoExpressionConverter.java | 128 ++--- .../ExtendedExpressionProtoConverter.java | 18 +- .../ProtoExtendedExpressionConverter.java | 27 +- .../extension/AbstractExtensionLookup.java | 20 +- .../java/io/substrait/extension/BidiMap.java | 18 +- .../extension/DefaultExtensionCatalog.java | 2 +- .../extension/ExtensionCollector.java | 60 +-- .../extension/ImmutableExtensionLookup.java | 80 +-- .../extension/ProtoExtensionConverter.java | 4 +- .../substrait/extension/SimpleExtension.java | 108 ++-- .../substrait/function/ParameterizedType.java | 6 +- .../function/ParameterizedTypeCreator.java | 31 +- .../function/ParameterizedTypeVisitor.java | 28 +- .../io/substrait/function/ToTypeString.java | 28 +- .../io/substrait/function/TypeExpression.java | 6 +- .../function/TypeExpressionCreator.java | 39 +- .../function/TypeExpressionVisitor.java | 34 +- .../src/main/java/io/substrait/hint/Hint.java | 4 +- .../io/substrait/plan/ProtoPlanConverter.java | 16 +- .../io/substrait/relation/AbstractDdlRel.java | 12 +- .../io/substrait/relation/AbstractRel.java | 2 +- .../relation/AbstractRelVisitor.java | 63 +-- .../substrait/relation/AbstractWriteRel.java | 18 +- .../java/io/substrait/relation/Aggregate.java | 2 +- .../AggregateFunctionProtoConverter.java | 10 +- .../relation/ConsistentPartitionWindow.java | 4 +- .../substrait/relation/CopyOnWriteUtils.java | 14 +- .../java/io/substrait/relation/Cross.java | 2 +- .../java/io/substrait/relation/EmptyScan.java | 2 +- .../java/io/substrait/relation/Expand.java | 8 +- .../ExpressionCopyOnWriteVisitor.java | 187 +++---- .../io/substrait/relation/ExtensionDdl.java | 2 +- .../io/substrait/relation/ExtensionLeaf.java | 4 +- .../io/substrait/relation/ExtensionMulti.java | 6 +- .../substrait/relation/ExtensionSingle.java | 5 +- .../io/substrait/relation/ExtensionTable.java | 4 +- .../io/substrait/relation/ExtensionWrite.java | 2 +- .../java/io/substrait/relation/Fetch.java | 2 +- .../java/io/substrait/relation/Filter.java | 2 +- .../main/java/io/substrait/relation/Join.java | 12 +- .../io/substrait/relation/LocalFiles.java | 2 +- .../java/io/substrait/relation/NamedDdl.java | 2 +- .../java/io/substrait/relation/NamedScan.java | 2 +- .../io/substrait/relation/NamedUpdate.java | 2 +- .../io/substrait/relation/NamedWrite.java | 2 +- .../java/io/substrait/relation/Project.java | 4 +- .../ProtoAggregateFunctionConverter.java | 20 +- .../substrait/relation/ProtoRelConverter.java | 372 +++++++------- .../main/java/io/substrait/relation/Rel.java | 8 +- .../relation/RelCopyOnWriteVisitor.java | 227 ++++---- .../substrait/relation/RelProtoConverter.java | 220 ++++---- .../main/java/io/substrait/relation/Set.java | 48 +- .../main/java/io/substrait/relation/Sort.java | 2 +- .../substrait/relation/VirtualTableScan.java | 30 +- .../relation/extensions/EmptyDetail.java | 8 +- .../substrait/relation/files/FileOrFiles.java | 9 +- .../relation/physical/BroadcastExchange.java | 2 +- .../substrait/relation/physical/HashJoin.java | 12 +- .../relation/physical/MergeJoin.java | 12 +- .../physical/MultiBucketExchange.java | 2 +- .../relation/physical/NestedLoopJoin.java | 12 +- .../relation/physical/RoundRobinExchange.java | 2 +- .../relation/physical/ScatterExchange.java | 2 +- .../physical/SingleBucketExchange.java | 2 +- .../java/io/substrait/type/Deserializers.java | 10 +- .../java/io/substrait/type/NamedStruct.java | 11 +- .../io/substrait/type/StringTypeVisitor.java | 58 +-- .../src/main/java/io/substrait/type/Type.java | 11 +- .../java/io/substrait/type/TypeCreator.java | 96 ++-- .../type/TypeExpressionEvaluator.java | 6 +- .../java/io/substrait/type/TypeVisitor.java | 2 +- .../main/java/io/substrait/type/YamlRead.java | 18 +- .../io/substrait/type/parser/ParseToPojo.java | 111 ++-- .../type/parser/TypeStringParser.java | 20 +- .../type/proto/BaseProtoConverter.java | 7 +- .../substrait/type/proto/BaseProtoTypes.java | 32 +- .../proto/ParameterizedProtoConverter.java | 58 ++- .../type/proto/ProtoTypeConverter.java | 16 +- .../proto/TypeExpressionProtoVisitor.java | 79 +-- .../type/proto/TypeProtoConverter.java | 32 +- .../java/io/substrait/util/DecimalUtil.java | 24 +- .../src/main/java/io/substrait/util/Util.java | 12 +- core/src/test/java/io/substrait/TestBase.java | 6 +- .../ExtendedExpressionRoundTripTest.java | 24 +- .../ExtensionCollectionMergeTest.java | 31 +- .../ExtensionCollectionUriUrnTest.java | 14 +- .../ExtensionCollectorUriUrnTest.java | 16 +- .../ImmutableExtensionLookupUriUrnTest.java | 216 ++++---- .../extension/TypeExtensionTest.java | 33 +- .../UriUrnMigrationEndToEndTest.java | 32 +- .../extension/UrnValidationTest.java | 15 +- .../substrait/relation/AggregateRelTest.java | 66 +-- .../relation/ProtoRelConverterTest.java | 79 +-- .../java/io/substrait/relation/SetTest.java | 14 +- .../substrait/relation/SpecVersionTest.java | 2 +- .../relation/VirtualTableScanTest.java | 8 +- .../substrait/type/parser/TestTypeParser.java | 9 +- .../type/proto/AggregateRoundtripTest.java | 26 +- ...istentPartitionWindowRelRoundtripTest.java | 22 +- .../type/proto/DdlRelRoundtripTest.java | 22 +- .../type/proto/ExchangeRelRoundtripTest.java | 23 +- .../type/proto/ExpandRelRoundtripTest.java | 8 +- .../type/proto/ExtensionRoundtripTest.java | 61 +-- .../proto/FieldReferenceRoundtripTest.java | 33 +- .../type/proto/FilterRelRoundtripTest.java | 51 +- .../type/proto/GenericRoundtripTest.java | 37 +- .../type/proto/IfThenRoundtripTest.java | 8 +- .../type/proto/JoinRoundtripTest.java | 16 +- .../type/proto/LiteralRoundtripTest.java | 6 +- .../type/proto/LocalFilesRoundtripTest.java | 16 +- .../type/proto/ProjectRelRoundtripTest.java | 34 +- .../type/proto/ReadRelRoundtripTest.java | 8 +- .../type/proto/SortRelRoundtripTest.java | 62 +-- .../type/proto/TestTypeRoundtrip.java | 14 +- .../type/proto/UpdateRelRoundtripTest.java | 10 +- .../type/proto/WriteRelRoundtripTest.java | 18 +- .../java/io/substrait/utils/StringHolder.java | 14 +- ...StringHolderHandlingProtoRelConverter.java | 14 +- .../main/java/io/substrait/examples/App.java | 8 +- .../examples/SparkConsumeSubstrait.java | 14 +- .../io/substrait/examples/SparkDataset.java | 20 +- .../io/substrait/examples/SparkHelper.java | 2 +- .../java/io/substrait/examples/SparkSQL.java | 22 +- .../examples/util/ExpressionStringify.java | 137 +++-- .../examples/util/FunctionArgStringify.java | 14 +- .../examples/util/ParentStringify.java | 10 +- .../examples/util/SubstraitStringify.java | 181 ++++--- .../examples/util/TypeStringify.java | 60 +-- .../isthmus/cli/IsthmusEntryPoint.java | 23 +- .../isthmus/cli/RegisterAtRuntime.java | 10 +- .../isthmus/cli/IsthmusEntryPointTest.java | 12 +- .../substrait/isthmus/AggregateFunctions.java | 16 +- .../isthmus/OuterReferenceResolver.java | 30 +- .../isthmus/PreCalciteAggregateValidator.java | 50 +- .../io/substrait/isthmus/RelNodeVisitor.java | 34 +- .../io/substrait/isthmus/SchemaCollector.java | 42 +- .../substrait/isthmus/SqlConverterBase.java | 6 +- .../isthmus/SqlExpressionToSubstrait.java | 88 ++-- .../io/substrait/isthmus/SqlKindFromRel.java | 77 +-- .../io/substrait/isthmus/SqlToSubstrait.java | 11 +- .../isthmus/SubstraitRelNodeConverter.java | 254 ++++----- .../isthmus/SubstraitRelVisitor.java | 230 +++++---- .../substrait/isthmus/SubstraitToCalcite.java | 73 +-- .../io/substrait/isthmus/SubstraitToSql.java | 2 +- .../io/substrait/isthmus/TypeConverter.java | 111 ++-- .../main/java/io/substrait/isthmus/Utils.java | 14 +- .../calcite/SubstraitOperatorTable.java | 10 +- .../isthmus/calcite/SubstraitSchema.java | 2 +- .../isthmus/calcite/SubstraitTable.java | 4 +- .../isthmus/calcite/rel/CreateTable.java | 4 +- .../isthmus/calcite/rel/CreateView.java | 4 +- .../calcite/rel/DdlSqlToRelConverter.java | 6 +- .../AggregateFunctionConverter.java | 47 +- .../isthmus/expression/CallConverters.java | 28 +- .../isthmus/expression/EnumConverter.java | 39 +- .../expression/ExpressionRexConverter.java | 255 +++++---- .../expression/FieldSelectionConverter.java | 24 +- .../isthmus/expression/FunctionConverter.java | 158 +++--- .../isthmus/expression/FunctionMappings.java | 8 +- .../IgnoreNullableAndParameters.java | 85 +-- .../expression/ListSqlOperatorFunctions.java | 8 +- .../LiteralConstructorConverter.java | 20 +- .../isthmus/expression/LiteralConverter.java | 62 +-- .../expression/RexExpressionConverter.java | 65 +-- .../expression/ScalarFunctionConverter.java | 53 +- .../expression/SortFieldConverter.java | 11 +- .../expression/SqrtFunctionMapper.java | 8 +- .../expression/TrimFunctionMapper.java | 34 +- .../expression/WindowBoundConverter.java | 8 +- .../expression/WindowFunctionConverter.java | 54 +- .../WindowRelFunctionConverter.java | 57 ++- .../isthmus/sql/SubstraitSqlDialect.java | 8 +- .../sql/SubstraitSqlStatementParser.java | 4 +- .../isthmus/sql/SubstraitSqlToCalcite.java | 52 +- .../isthmus/sql/SubstraitSqlValidator.java | 2 +- .../isthmus/AggregationFunctionsTest.java | 25 +- .../isthmus/ArithmeticFunctionTest.java | 120 ++--- .../io/substrait/isthmus/CalciteCallTest.java | 17 +- .../substrait/isthmus/CalciteLiteralTest.java | 96 ++-- .../io/substrait/isthmus/CalciteObjs.java | 10 +- .../io/substrait/isthmus/CalciteTypeTest.java | 87 ++-- .../isthmus/ComparisonFunctionsTest.java | 28 +- .../isthmus/ComplexAggregateTest.java | 38 +- .../io/substrait/isthmus/ComplexSortTest.java | 48 +- .../substrait/isthmus/CustomFunctionTest.java | 122 ++--- .../substrait/isthmus/DdlRoundtripTest.java | 4 +- .../isthmus/DuplicateFunctionUrnTest.java | 64 +-- .../isthmus/EmptyArrayLiteralTest.java | 12 +- .../isthmus/ExpressionConvertabilityTest.java | 31 +- .../isthmus/ExtendedExpressionTestBase.java | 27 +- .../java/io/substrait/isthmus/FetchTest.java | 6 +- .../isthmus/FunctionConversionTest.java | 50 +- .../substrait/isthmus/KeyConstraintsTest.java | 8 +- .../isthmus/LogarithmicFunctionTest.java | 12 +- .../substrait/isthmus/NameRoundtripTest.java | 21 +- .../isthmus/NestedStructQueryTest.java | 75 +-- .../isthmus/OptimizerIntegrationTest.java | 12 +- .../io/substrait/isthmus/PlanTestBase.java | 161 +++--- .../isthmus/ProtoPlanConverterTest.java | 30 +- .../isthmus/RelCopyOnWriteVisitorTest.java | 63 +-- .../java/io/substrait/isthmus/RelCreator.java | 29 +- .../isthmus/RelExtensionRoundtripTest.java | 117 +++-- .../isthmus/RoundingFunctionTest.java | 12 +- .../isthmus/SchemaCollectorTest.java | 56 +- .../SimpleExtendedExpressionsTest.java | 8 +- .../substrait/isthmus/StringFunctionTest.java | 144 +++--- .../substrait/isthmus/SubqueryPlanTest.java | 90 ++-- .../substrait/isthmus/Substrait2SqlTest.java | 20 +- .../SubstraitExpressionConverterTest.java | 62 +-- .../SubstraitRelNodeConverterTest.java | 98 ++-- .../isthmus/SubstraitToCalciteTest.java | 40 +- .../SubtraitRelVisitorExtensionTest.java | 8 +- .../io/substrait/isthmus/TpcdsQueryTest.java | 12 +- .../io/substrait/isthmus/TpchQueryTest.java | 10 +- .../substrait/isthmus/WindowFunctionTest.java | 41 +- .../AggregateFunctionConverterTest.java | 4 +- .../io/substrait/isthmus/utils/SetUtils.java | 8 +- .../isthmus/utils/UserTypeFactory.java | 12 +- 228 files changed, 5009 insertions(+), 4583 deletions(-) diff --git a/build-logic/src/main/resources/substrait-pmd.xml b/build-logic/src/main/resources/substrait-pmd.xml index 5282c3a06..fd33f3eb9 100644 --- a/build-logic/src/main/resources/substrait-pmd.xml +++ b/build-logic/src/main/resources/substrait-pmd.xml @@ -12,6 +12,9 @@ + + + diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 84af839d8..3a7035bc3 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -54,45 +54,46 @@ public class SubstraitBuilder { private final SimpleExtension.ExtensionCollection extensions; - public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) { + public SubstraitBuilder(final SimpleExtension.ExtensionCollection extensions) { this.extensions = extensions; } // Relations - public Aggregate.Measure measure(AggregateFunctionInvocation aggFn) { + public Aggregate.Measure measure(final AggregateFunctionInvocation aggFn) { return Aggregate.Measure.builder().function(aggFn).build(); } - public Aggregate.Measure measure(AggregateFunctionInvocation aggFn, Expression preMeasureFilter) { + public Aggregate.Measure measure( + final AggregateFunctionInvocation aggFn, final Expression preMeasureFilter) { return Aggregate.Measure.builder().function(aggFn).preMeasureFilter(preMeasureFilter).build(); } public Aggregate aggregate( - Function groupingFn, - Function> measuresFn, - Rel input) { - Function> groupingsFn = + final Function groupingFn, + final Function> measuresFn, + final Rel input) { + final Function> groupingsFn = groupingFn.andThen(g -> Stream.of(g).collect(Collectors.toList())); return aggregate(groupingsFn, measuresFn, Optional.empty(), input); } public Aggregate aggregate( - Function groupingFn, - Function> measuresFn, - Rel.Remap remap, - Rel input) { - Function> groupingsFn = + final Function groupingFn, + final Function> measuresFn, + final Rel.Remap remap, + final Rel input) { + final Function> groupingsFn = groupingFn.andThen(g -> Stream.of(g).collect(Collectors.toList())); return aggregate(groupingsFn, measuresFn, Optional.of(remap), input); } private Aggregate aggregate( - Function> groupingsFn, - Function> measuresFn, - Optional remap, - Rel input) { - List groupings = groupingsFn.apply(input); - List measures = measuresFn.apply(input); + final Function> groupingsFn, + final Function> measuresFn, + final Optional remap, + final Rel input) { + final List groupings = groupingsFn.apply(input); + final List measures = measuresFn.apply(input); return Aggregate.builder() .groupings(groupings) .measures(measures) @@ -101,57 +102,64 @@ private Aggregate aggregate( .build(); } - public Cross cross(Rel left, Rel right) { + public Cross cross(final Rel left, final Rel right) { return cross(left, right, Optional.empty()); } - public Cross cross(Rel left, Rel right, Rel.Remap remap) { + public Cross cross(final Rel left, final Rel right, final Rel.Remap remap) { return cross(left, right, Optional.of(remap)); } - private Cross cross(Rel left, Rel right, Optional remap) { + private Cross cross(final Rel left, final Rel right, final Optional remap) { return Cross.builder().left(left).right(right).remap(remap).build(); } - public Fetch fetch(long offset, long count, Rel input) { + public Fetch fetch(final long offset, final long count, final Rel input) { return fetch(offset, OptionalLong.of(count), Optional.empty(), input); } - public Fetch fetch(long offset, long count, Rel.Remap remap, Rel input) { + public Fetch fetch(final long offset, final long count, final Rel.Remap remap, final Rel input) { return fetch(offset, OptionalLong.of(count), Optional.of(remap), input); } - public Fetch limit(long limit, Rel input) { + public Fetch limit(final long limit, final Rel input) { return fetch(0, OptionalLong.of(limit), Optional.empty(), input); } - public Fetch limit(long limit, Rel.Remap remap, Rel input) { + public Fetch limit(final long limit, final Rel.Remap remap, final Rel input) { return fetch(0, OptionalLong.of(limit), Optional.of(remap), input); } - public Fetch offset(long offset, Rel input) { + public Fetch offset(final long offset, final Rel input) { return fetch(offset, OptionalLong.empty(), Optional.empty(), input); } - public Fetch offset(long offset, Rel.Remap remap, Rel input) { + public Fetch offset(final long offset, final Rel.Remap remap, final Rel input) { return fetch(offset, OptionalLong.empty(), Optional.of(remap), input); } - private Fetch fetch(long offset, OptionalLong count, Optional remap, Rel input) { + private Fetch fetch( + final long offset, + final OptionalLong count, + final Optional remap, + final Rel input) { return Fetch.builder().offset(offset).count(count).input(input).remap(remap).build(); } - public Filter filter(Function conditionFn, Rel input) { + public Filter filter(final Function conditionFn, final Rel input) { return filter(conditionFn, Optional.empty(), input); } - public Filter filter(Function conditionFn, Rel.Remap remap, Rel input) { + public Filter filter( + final Function conditionFn, final Rel.Remap remap, final Rel input) { return filter(conditionFn, Optional.of(remap), input); } private Filter filter( - Function conditionFn, Optional remap, Rel input) { - Expression condition = conditionFn.apply(input); + final Function conditionFn, + final Optional remap, + final Rel input) { + final Expression condition = conditionFn.apply(input); return Filter.builder().input(input).condition(condition).remap(remap).build(); } @@ -159,7 +167,7 @@ public static final class JoinInput { private final Rel left; private final Rel right; - JoinInput(Rel left, Rel right) { + JoinInput(final Rel left, final Rel right) { this.left = left; this.right = right; } @@ -173,36 +181,43 @@ public Rel right() { } } - public Join innerJoin(Function conditionFn, Rel left, Rel right) { + public Join innerJoin( + final Function conditionFn, final Rel left, final Rel right) { return join(conditionFn, Join.JoinType.INNER, left, right); } public Join innerJoin( - Function conditionFn, Rel.Remap remap, Rel left, Rel right) { + final Function conditionFn, + final Rel.Remap remap, + final Rel left, + final Rel right) { return join(conditionFn, Join.JoinType.INNER, remap, left, right); } public Join join( - Function conditionFn, Join.JoinType joinType, Rel left, Rel right) { + final Function conditionFn, + final Join.JoinType joinType, + final Rel left, + final Rel right) { return join(conditionFn, joinType, Optional.empty(), left, right); } public Join join( - Function conditionFn, - Join.JoinType joinType, - Rel.Remap remap, - Rel left, - Rel right) { + final Function conditionFn, + final Join.JoinType joinType, + final Rel.Remap remap, + final Rel left, + final Rel right) { return join(conditionFn, joinType, Optional.of(remap), left, right); } private Join join( - Function conditionFn, - Join.JoinType joinType, - Optional remap, - Rel left, - Rel right) { - Expression condition = conditionFn.apply(new JoinInput(left, right)); + final Function conditionFn, + final Join.JoinType joinType, + final Optional remap, + final Rel left, + final Rel right) { + final Expression condition = conditionFn.apply(new JoinInput(left, right)); return Join.builder() .left(left) .right(right) @@ -213,21 +228,21 @@ private Join join( } public HashJoin hashJoin( - List leftKeys, - List rightKeys, - HashJoin.JoinType joinType, - Rel left, - Rel right) { + final List leftKeys, + final List rightKeys, + final HashJoin.JoinType joinType, + final Rel left, + final Rel right) { return hashJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right); } public HashJoin hashJoin( - List leftKeys, - List rightKeys, - HashJoin.JoinType joinType, - Optional remap, - Rel left, - Rel right) { + final List leftKeys, + final List rightKeys, + final HashJoin.JoinType joinType, + final Optional remap, + final Rel left, + final Rel right) { return HashJoin.builder() .left(left) .right(right) @@ -241,21 +256,21 @@ public HashJoin hashJoin( } public MergeJoin mergeJoin( - List leftKeys, - List rightKeys, - MergeJoin.JoinType joinType, - Rel left, - Rel right) { + final List leftKeys, + final List rightKeys, + final MergeJoin.JoinType joinType, + final Rel left, + final Rel right) { return mergeJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right); } public MergeJoin mergeJoin( - List leftKeys, - List rightKeys, - MergeJoin.JoinType joinType, - Optional remap, - Rel left, - Rel right) { + final List leftKeys, + final List rightKeys, + final MergeJoin.JoinType joinType, + final Optional remap, + final Rel left, + final Rel right) { return MergeJoin.builder() .left(left) .right(right) @@ -269,20 +284,20 @@ public MergeJoin mergeJoin( } public NestedLoopJoin nestedLoopJoin( - Function conditionFn, - NestedLoopJoin.JoinType joinType, - Rel left, - Rel right) { + final Function conditionFn, + final NestedLoopJoin.JoinType joinType, + final Rel left, + final Rel right) { return nestedLoopJoin(conditionFn, joinType, Optional.empty(), left, right); } private NestedLoopJoin nestedLoopJoin( - Function conditionFn, - NestedLoopJoin.JoinType joinType, - Optional remap, - Rel left, - Rel right) { - Expression condition = conditionFn.apply(new JoinInput(left, right)); + final Function conditionFn, + final NestedLoopJoin.JoinType joinType, + final Optional remap, + final Rel left, + final Rel right) { + final Expression condition = conditionFn.apply(new JoinInput(left, right)); return NestedLoopJoin.builder() .left(left) .right(right) @@ -293,25 +308,27 @@ private NestedLoopJoin nestedLoopJoin( } public NamedScan namedScan( - Iterable tableName, Iterable columnNames, Iterable types) { + final Iterable tableName, + final Iterable columnNames, + final Iterable types) { return namedScan(tableName, columnNames, types, Optional.empty()); } public NamedScan namedScan( - Iterable tableName, - Iterable columnNames, - Iterable types, - Rel.Remap remap) { + final Iterable tableName, + final Iterable columnNames, + final Iterable types, + final Rel.Remap remap) { return namedScan(tableName, columnNames, types, Optional.of(remap)); } private NamedScan namedScan( - Iterable tableName, - Iterable columnNames, - Iterable types, - Optional remap) { - Type.Struct struct = Type.Struct.builder().addAllFields(types).nullable(false).build(); - NamedStruct namedStruct = NamedStruct.of(columnNames, struct); + final Iterable tableName, + final Iterable columnNames, + final Iterable types, + final Optional remap) { + final Type.Struct struct = Type.Struct.builder().addAllFields(types).nullable(false).build(); + final NamedStruct namedStruct = NamedStruct.of(columnNames, struct); return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build(); } @@ -322,37 +339,37 @@ public EmptyScan emptyScan() { } public NamedWrite namedWrite( - Iterable tableName, - Iterable columnNames, - AbstractWriteRel.WriteOp op, - AbstractWriteRel.CreateMode createMode, - AbstractWriteRel.OutputMode outputMode, - Rel input) { + final Iterable tableName, + final Iterable columnNames, + final AbstractWriteRel.WriteOp op, + final AbstractWriteRel.CreateMode createMode, + final AbstractWriteRel.OutputMode outputMode, + final Rel input) { return namedWrite(tableName, columnNames, op, createMode, outputMode, input, Optional.empty()); } public NamedWrite namedWrite( - Iterable tableName, - Iterable columnNames, - AbstractWriteRel.WriteOp op, - AbstractWriteRel.CreateMode createMode, - AbstractWriteRel.OutputMode outputMode, - Rel input, - Rel.Remap remap) { + final Iterable tableName, + final Iterable columnNames, + final AbstractWriteRel.WriteOp op, + final AbstractWriteRel.CreateMode createMode, + final AbstractWriteRel.OutputMode outputMode, + final Rel input, + final Rel.Remap remap) { return namedWrite( tableName, columnNames, op, createMode, outputMode, input, Optional.of(remap)); } private NamedWrite namedWrite( - Iterable tableName, - Iterable columnNames, - AbstractWriteRel.WriteOp op, - AbstractWriteRel.CreateMode createMode, - AbstractWriteRel.OutputMode outputMode, - Rel input, - Optional remap) { - Type.Struct struct = input.getRecordType(); - NamedStruct namedStruct = NamedStruct.of(columnNames, struct); + final Iterable tableName, + final Iterable columnNames, + final AbstractWriteRel.WriteOp op, + final AbstractWriteRel.CreateMode createMode, + final AbstractWriteRel.OutputMode outputMode, + final Rel input, + final Optional remap) { + final Type.Struct struct = input.getRecordType(); + final NamedStruct namedStruct = NamedStruct.of(columnNames, struct); return NamedWrite.builder() .names(tableName) .tableSchema(namedStruct) @@ -365,39 +382,39 @@ private NamedWrite namedWrite( } public NamedUpdate namedUpdate( - Iterable tableName, - Iterable columnNames, - List transformations, - Expression condition, - boolean nullable) { + final Iterable tableName, + final Iterable columnNames, + final List transformations, + final Expression condition, + final boolean nullable) { return namedUpdate( tableName, columnNames, transformations, condition, nullable, Optional.empty()); } public NamedUpdate namedUpdate( - Iterable tableName, - Iterable columnNames, - List transformations, - Expression condition, - boolean nullable, - Rel.Remap remap) { + final Iterable tableName, + final Iterable columnNames, + final List transformations, + final Expression condition, + final boolean nullable, + final Rel.Remap remap) { return namedUpdate( tableName, columnNames, transformations, condition, nullable, Optional.of(remap)); } private NamedUpdate namedUpdate( - Iterable tableName, - Iterable columnNames, - List transformations, - Expression condition, - boolean nullable, - Optional remap) { - List types = + final Iterable tableName, + final Iterable columnNames, + final List transformations, + final Expression condition, + final boolean nullable, + final Optional remap) { + final List types = transformations.stream() .map(t -> t.getTransformation().getType()) .collect(Collectors.toList()); - Type.Struct struct = Type.Struct.builder().fields(types).nullable(nullable).build(); - NamedStruct namedStruct = NamedStruct.of(columnNames, struct); + final Type.Struct struct = Type.Struct.builder().fields(types).nullable(nullable).build(); + final NamedStruct namedStruct = NamedStruct.of(columnNames, struct); return NamedUpdate.builder() .names(tableName) .tableSchema(namedStruct) @@ -407,90 +424,97 @@ private NamedUpdate namedUpdate( .build(); } - public Project project(Function> expressionsFn, Rel input) { + public Project project( + final Function> expressionsFn, final Rel input) { return project(expressionsFn, Optional.empty(), input); } public Project project( - Function> expressionsFn, Rel.Remap remap, Rel input) { + final Function> expressionsFn, + final Rel.Remap remap, + final Rel input) { return project(expressionsFn, Optional.of(remap), input); } private Project project( - Function> expressionsFn, - Optional remap, - Rel input) { - Iterable expressions = expressionsFn.apply(input); + final Function> expressionsFn, + final Optional remap, + final Rel input) { + final Iterable expressions = expressionsFn.apply(input); return Project.builder().input(input).expressions(expressions).remap(remap).build(); } - public Expand expand(Function> fieldsFn, Rel input) { + public Expand expand( + final Function> fieldsFn, final Rel input) { return expand(fieldsFn, Optional.empty(), input); } public Expand expand( - Function> fieldsFn, Rel.Remap remap, Rel input) { + final Function> fieldsFn, + final Rel.Remap remap, + final Rel input) { return expand(fieldsFn, Optional.of(remap), input); } private Expand expand( - Function> fieldsFn, - Optional remap, - Rel input) { - Iterable fields = fieldsFn.apply(input); + final Function> fieldsFn, + final Optional remap, + final Rel input) { + final Iterable fields = fieldsFn.apply(input); return Expand.builder().input(input).fields(fields).remap(remap).build(); } - public Set set(Set.SetOp op, Rel... inputs) { + public Set set(final Set.SetOp op, final Rel... inputs) { return set(op, Optional.empty(), inputs); } - public Set set(Set.SetOp op, Rel.Remap remap, Rel... inputs) { + public Set set(final Set.SetOp op, final Rel.Remap remap, final Rel... inputs) { return set(op, Optional.of(remap), inputs); } - private Set set(Set.SetOp op, Optional remap, Rel... inputs) { + private Set set(final Set.SetOp op, final Optional remap, final Rel... inputs) { return Set.builder().setOp(op).remap(remap).addAllInputs(Arrays.asList(inputs)).build(); } - public Sort sort(Function> sortFieldFn, Rel input) { + public Sort sort( + final Function> sortFieldFn, final Rel input) { return sort(sortFieldFn, Optional.empty(), input); } public Sort sort( - Function> sortFieldFn, - Rel.Remap remap, - Rel input) { + final Function> sortFieldFn, + final Rel.Remap remap, + final Rel input) { return sort(sortFieldFn, Optional.of(remap), input); } private Sort sort( - Function> sortFieldFn, - Optional remap, - Rel input) { - Iterable condition = sortFieldFn.apply(input); + final Function> sortFieldFn, + final Optional remap, + final Rel input) { + final Iterable condition = sortFieldFn.apply(input); return Sort.builder().input(input).sortFields(condition).remap(remap).build(); } // Expressions - public Expression.BoolLiteral bool(boolean v) { + public Expression.BoolLiteral bool(final boolean v) { return Expression.BoolLiteral.builder().value(v).build(); } - public Expression.I32Literal i32(int v) { + public Expression.I32Literal i32(final int v) { return Expression.I32Literal.builder().value(v).build(); } - public Expression.FP64Literal fp64(double v) { + public Expression.FP64Literal fp64(final double v) { return Expression.FP64Literal.builder().value(v).build(); } - public Expression.StrLiteral str(String s) { + public Expression.StrLiteral str(final String s) { return Expression.StrLiteral.builder().value(s).build(); } - public Expression cast(Expression input, Type type) { + public Expression cast(final Expression input, final Type type) { return Cast.builder() .input(input) .type(type) @@ -498,46 +522,46 @@ public Expression cast(Expression input, Type type) { .build(); } - public FieldReference fieldReference(Rel input, int index) { + public FieldReference fieldReference(final Rel input, final int index) { return FieldReference.newInputRelReference(index, input); } - public List fieldReferences(Rel input, int... indexes) { + public List fieldReferences(final Rel input, final int... indexes) { return Arrays.stream(indexes) .mapToObj(index -> fieldReference(input, index)) .collect(java.util.stream.Collectors.toList()); } - public FieldReference fieldReference(List inputs, int index) { + public FieldReference fieldReference(final List inputs, final int index) { return FieldReference.newInputRelReference(index, inputs); } - public List fieldReferences(List inputs, int... indexes) { + public List fieldReferences(final List inputs, final int... indexes) { return Arrays.stream(indexes) .mapToObj(index -> fieldReference(inputs, index)) .collect(java.util.stream.Collectors.toList()); } - public IfThen ifThen(Iterable ifClauses, Expression elseClause) { + public IfThen ifThen(final Iterable ifClauses, final Expression elseClause) { return IfThen.builder().addAllIfClauses(ifClauses).elseClause(elseClause).build(); } - public IfClause ifClause(Expression condition, Expression then) { + public IfClause ifClause(final Expression condition, final Expression then) { return IfClause.builder().condition(condition).then(then).build(); } - public Expression singleOrList(Expression condition, Expression... options) { + public Expression singleOrList(final Expression condition, final Expression... options) { return SingleOrList.builder().condition(condition).addOptions(options).build(); } - public Expression.InPredicate inPredicate(Rel haystack, Expression... needles) { + public Expression.InPredicate inPredicate(final Rel haystack, final Expression... needles) { return Expression.InPredicate.builder() .addAllNeedles(Arrays.asList(needles)) .haystack(haystack) .build(); } - public List sortFields(Rel input, int... indexes) { + public List sortFields(final Rel input, final int... indexes) { return Arrays.stream(indexes) .mapToObj( index -> @@ -549,16 +573,18 @@ public List sortFields(Rel input, int... indexes) { } public Expression.SortField sortField( - Expression expression, Expression.SortDirection sortDirection) { + final Expression expression, final Expression.SortDirection sortDirection) { return Expression.SortField.builder().expr(expression).direction(sortDirection).build(); } - public SwitchClause switchClause(Expression.Literal condition, Expression then) { + public SwitchClause switchClause(final Expression.Literal condition, final Expression then) { return SwitchClause.builder().condition(condition).then(then).build(); } public Switch switchExpression( - Expression match, Iterable clauses, Expression defaultClause) { + final Expression match, + final Iterable clauses, + final Expression defaultClause) { return Switch.builder() .match(match) .addAllSwitchClauses(clauses) @@ -569,8 +595,8 @@ public Switch switchExpression( // Aggregate Functions public AggregateFunctionInvocation aggregateFn( - String urn, String key, Type outputType, Expression... args) { - SimpleExtension.AggregateFunctionVariant declaration = + final String urn, final String key, final Type outputType, final Expression... args) { + final SimpleExtension.AggregateFunctionVariant declaration = extensions.getAggregateFunction(SimpleExtension.FunctionAnchor.of(urn, key)); return AggregateFunctionInvocation.builder() .arguments(Arrays.stream(args).collect(java.util.stream.Collectors.toList())) @@ -581,17 +607,17 @@ public AggregateFunctionInvocation aggregateFn( .build(); } - public Aggregate.Grouping grouping(Rel input, int... indexes) { - List columns = fieldReferences(input, indexes); + public Aggregate.Grouping grouping(final Rel input, final int... indexes) { + final List columns = fieldReferences(input, indexes); return Aggregate.Grouping.builder().addAllExpressions(columns).build(); } - public Aggregate.Grouping grouping(Expression... expressions) { + public Aggregate.Grouping grouping(final Expression... expressions) { return Aggregate.Grouping.builder().addExpressions(expressions).build(); } - public Aggregate.Measure count(Rel input, int field) { - SimpleExtension.AggregateFunctionVariant declaration = + public Aggregate.Measure count(final Rel input, final int field) { + final SimpleExtension.AggregateFunctionVariant declaration = extensions.getAggregateFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_AGGREGATE_GENERIC, "count:any")); @@ -624,11 +650,11 @@ public Measure countStar() { .build()); } - public Aggregate.Measure min(Rel input, int field) { + public Aggregate.Measure min(final Rel input, final int field) { return min(fieldReference(input, field)); } - public Aggregate.Measure min(Expression expr) { + public Aggregate.Measure min(final Expression expr) { return singleArgumentArithmeticAggregate( expr, "min", @@ -636,11 +662,11 @@ public Aggregate.Measure min(Expression expr) { TypeCreator.asNullable(expr.getType())); } - public Aggregate.Measure max(Rel input, int field) { + public Aggregate.Measure max(final Rel input, final int field) { return max(fieldReference(input, field)); } - public Aggregate.Measure max(Expression expr) { + public Aggregate.Measure max(final Expression expr) { return singleArgumentArithmeticAggregate( expr, "max", @@ -648,11 +674,11 @@ public Aggregate.Measure max(Expression expr) { TypeCreator.asNullable(expr.getType())); } - public Aggregate.Measure avg(Rel input, int field) { + public Aggregate.Measure avg(final Rel input, final int field) { return avg(fieldReference(input, field)); } - public Aggregate.Measure avg(Expression expr) { + public Aggregate.Measure avg(final Expression expr) { return singleArgumentArithmeticAggregate( expr, "avg", @@ -660,11 +686,11 @@ public Aggregate.Measure avg(Expression expr) { TypeCreator.asNullable(expr.getType())); } - public Aggregate.Measure sum(Rel input, int field) { + public Aggregate.Measure sum(final Rel input, final int field) { return sum(fieldReference(input, field)); } - public Aggregate.Measure sum(Expression expr) { + public Aggregate.Measure sum(final Expression expr) { return singleArgumentArithmeticAggregate( expr, "sum", @@ -672,11 +698,11 @@ public Aggregate.Measure sum(Expression expr) { TypeCreator.asNullable(expr.getType())); } - public Aggregate.Measure sum0(Rel input, int field) { + public Aggregate.Measure sum0(final Rel input, final int field) { return sum(fieldReference(input, field)); } - public Aggregate.Measure sum0(Expression expr) { + public Aggregate.Measure sum0(final Expression expr) { return singleArgumentArithmeticAggregate( expr, "sum0", @@ -685,9 +711,9 @@ public Aggregate.Measure sum0(Expression expr) { } private Aggregate.Measure singleArgumentArithmeticAggregate( - Expression expr, String functionName, Type outputType) { - String typeString = ToTypeString.apply(expr.getType()); - SimpleExtension.AggregateFunctionVariant declaration = + final Expression expr, final String functionName, final Type outputType) { + final String typeString = ToTypeString.apply(expr.getType()); + final SimpleExtension.AggregateFunctionVariant declaration = extensions.getAggregateFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, @@ -707,9 +733,9 @@ private Aggregate.Measure singleArgumentArithmeticAggregate( // Scalar Functions - public Expression.ScalarFunctionInvocation negate(Expression expr) { + public Expression.ScalarFunctionInvocation negate(final Expression expr) { // output type of negate is the same as the input type - Type outputType = expr.getType(); + final Type outputType = expr.getType(); return scalarFn( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, String.format("negate:%s", ToTypeString.apply(outputType)), @@ -717,29 +743,31 @@ public Expression.ScalarFunctionInvocation negate(Expression expr) { expr); } - public Expression.ScalarFunctionInvocation add(Expression left, Expression right) { + public Expression.ScalarFunctionInvocation add(final Expression left, final Expression right) { return arithmeticFunction("add", left, right); } - public Expression.ScalarFunctionInvocation subtract(Expression left, Expression right) { + public Expression.ScalarFunctionInvocation subtract( + final Expression left, final Expression right) { return arithmeticFunction("substract", left, right); } - public Expression.ScalarFunctionInvocation multiply(Expression left, Expression right) { + public Expression.ScalarFunctionInvocation multiply( + final Expression left, final Expression right) { return arithmeticFunction("multiply", left, right); } - public Expression.ScalarFunctionInvocation divide(Expression left, Expression right) { + public Expression.ScalarFunctionInvocation divide(final Expression left, final Expression right) { return arithmeticFunction("divide", left, right); } private Expression.ScalarFunctionInvocation arithmeticFunction( - String fname, Expression left, Expression right) { - String leftTypeStr = ToTypeString.apply(left.getType()); - String rightTypeStr = ToTypeString.apply(right.getType()); - String key = String.format("%s:%s_%s", fname, leftTypeStr, rightTypeStr); + final String fname, final Expression left, final Expression right) { + final String leftTypeStr = ToTypeString.apply(left.getType()); + final String rightTypeStr = ToTypeString.apply(right.getType()); + final String key = String.format("%s:%s_%s", fname, leftTypeStr, rightTypeStr); - boolean isOutputNullable = left.getType().nullable() || right.getType().nullable(); + final boolean isOutputNullable = left.getType().nullable() || right.getType().nullable(); Type outputType = left.getType(); outputType = isOutputNullable @@ -749,30 +777,30 @@ private Expression.ScalarFunctionInvocation arithmeticFunction( return scalarFn(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, key, outputType, left, right); } - public Expression.ScalarFunctionInvocation equal(Expression left, Expression right) { + public Expression.ScalarFunctionInvocation equal(final Expression left, final Expression right) { return scalarFn( DefaultExtensionCatalog.FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right); } - public Expression.ScalarFunctionInvocation and(Expression... args) { + public Expression.ScalarFunctionInvocation and(final Expression... args) { // If any arg is nullable, the output of and is potentially nullable // For example: false and null = null - boolean isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable()); - Type outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN; + final boolean isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable()); + final Type outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN; return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "and:bool", outputType, args); } - public Expression.ScalarFunctionInvocation or(Expression... args) { + public Expression.ScalarFunctionInvocation or(final Expression... args) { // If any arg is nullable, the output of or is potentially nullable // For example: false or null = null - boolean isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable()); - Type outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN; + final boolean isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable()); + final Type outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN; return scalarFn(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, "or:bool", outputType, args); } public Expression.ScalarFunctionInvocation scalarFn( - String urn, String key, Type outputType, FunctionArg... args) { - SimpleExtension.ScalarFunctionVariant declaration = + final String urn, final String key, final Type outputType, final FunctionArg... args) { + final SimpleExtension.ScalarFunctionVariant declaration = extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(urn, key)); return Expression.ScalarFunctionInvocation.builder() .declaration(declaration) @@ -782,16 +810,16 @@ public Expression.ScalarFunctionInvocation scalarFn( } public Expression.WindowFunctionInvocation windowFn( - String urn, - String key, - Type outputType, - Expression.AggregationPhase aggregationPhase, - Expression.AggregationInvocation invocation, - Expression.WindowBoundsType boundsType, - WindowBound lowerBound, - WindowBound upperBound, - Expression... args) { - SimpleExtension.WindowFunctionVariant declaration = + final String urn, + final String key, + final Type outputType, + final Expression.AggregationPhase aggregationPhase, + final Expression.AggregationInvocation invocation, + final Expression.WindowBoundsType boundsType, + final WindowBound lowerBound, + final WindowBound upperBound, + final Expression... args) { + final SimpleExtension.WindowFunctionVariant declaration = extensions.getWindowFunction(SimpleExtension.FunctionAnchor.of(urn, key)); return Expression.WindowFunctionInvocation.builder() .declaration(declaration) @@ -807,29 +835,29 @@ public Expression.WindowFunctionInvocation windowFn( // Types - public Type.UserDefined userDefinedType(String urn, String typeName) { + public Type.UserDefined userDefinedType(final String urn, final String typeName) { return Type.UserDefined.builder().urn(urn).name(typeName).nullable(false).build(); } // Misc - public Plan.Root root(Rel rel) { + public Plan.Root root(final Rel rel) { return Plan.Root.builder().input(rel).build(); } - public Plan plan(Plan.Root root) { + public Plan plan(final Plan.Root root) { return Plan.builder().addRoots(root).build(); } - public Rel.Remap remap(Integer... fields) { + public Rel.Remap remap(final Integer... fields) { return Rel.Remap.of(Arrays.asList(fields)); } - public Expression scalarSubquery(Rel input, Type type) { + public Expression scalarSubquery(final Rel input, final Type type) { return Expression.ScalarSubquery.builder().input(input).type(type).build(); } - public Expression exists(Rel rel) { + public Expression exists(final Rel rel) { return Expression.SetPredicate.builder() .tuples(rel) .predicateOp(PredicateOp.PREDICATE_OP_EXISTS) diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 072507295..138963816 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -7,202 +7,202 @@ public abstract class AbstractExpressionVisitor R accept( - SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor, C context) + final SimpleExtension.Function fnDef, + final int argIdx, + final FuncArgVisitor fnArgVisitor, + final C context) throws E { return fnArgVisitor.visitEnumArg(fnDef, argIdx, this, context); } - static EnumArg of(SimpleExtension.EnumArgument enumArg, String option) { + static EnumArg of(final SimpleExtension.EnumArgument enumArg, final String option) { assert (enumArg.options().contains(option)); return builder().value(Optional.of(option)).build(); } - static EnumArg of(String value) { + static EnumArg of(final String value) { return builder().value(Optional.of(value)).build(); } diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 42c3c5118..c9750c602 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -20,7 +20,10 @@ public interface Expression extends FunctionArg { @Override default R accept( - SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor, C context) + final SimpleExtension.Function fnDef, + final int argIdx, + final FuncArgVisitor fnArgVisitor, + final C context) throws E { return fnArgVisitor.visitExpr(fnDef, argIdx, this, context); } @@ -50,7 +53,7 @@ public static ImmutableExpression.NullLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -70,7 +73,7 @@ public static ImmutableExpression.BoolLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -90,7 +93,7 @@ public static ImmutableExpression.I8Literal.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -110,7 +113,7 @@ public static ImmutableExpression.I16Literal.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -130,7 +133,7 @@ public static ImmutableExpression.I32Literal.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -150,7 +153,7 @@ public static ImmutableExpression.I64Literal.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -170,7 +173,7 @@ public static ImmutableExpression.FP32Literal.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -190,7 +193,7 @@ public static ImmutableExpression.FP64Literal.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -210,7 +213,7 @@ public static ImmutableExpression.StrLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -230,7 +233,7 @@ public static ImmutableExpression.BinaryLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -250,7 +253,7 @@ public static ImmutableExpression.TimestampLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -270,7 +273,7 @@ public static ImmutableExpression.TimeLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -290,7 +293,7 @@ public static ImmutableExpression.DateLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -310,7 +313,7 @@ public static ImmutableExpression.TimestampTZLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -332,7 +335,7 @@ public static ImmutableExpression.PrecisionTimestampLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -354,7 +357,7 @@ public static ImmutableExpression.PrecisionTimestampTZLiteral.Builder builder() @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -376,7 +379,7 @@ public static ImmutableExpression.IntervalYearLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -402,7 +405,7 @@ public static ImmutableExpression.IntervalDayLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -434,7 +437,7 @@ public static ImmutableExpression.IntervalCompoundLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -454,12 +457,12 @@ public static ImmutableExpression.UUIDLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } public ByteString toBytes() { - ByteBuffer bb = ByteBuffer.allocate(16); + final ByteBuffer bb = ByteBuffer.allocate(16); bb.putLong(value().getMostSignificantBits()); bb.putLong(value().getLeastSignificantBits()); bb.flip(); @@ -482,7 +485,7 @@ public static ImmutableExpression.FixedCharLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -504,7 +507,7 @@ public static ImmutableExpression.VarCharLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -524,7 +527,7 @@ public static ImmutableExpression.FixedBinaryLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -548,7 +551,7 @@ public static ImmutableExpression.DecimalLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -571,7 +574,7 @@ public static ImmutableExpression.MapLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -593,7 +596,7 @@ public static ImmutableExpression.EmptyMapLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -613,7 +616,7 @@ public static ImmutableExpression.ListLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -633,7 +636,7 @@ public static ImmutableExpression.EmptyListLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -657,7 +660,7 @@ public static ImmutableExpression.StructLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -681,7 +684,7 @@ public static ImmutableExpression.UserDefinedLiteral.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -705,7 +708,7 @@ public static ImmutableExpression.Switch.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -729,7 +732,7 @@ abstract class IfThen implements Expression { @Override public Type getType() { - Type elseType = elseClause().getType(); + final Type elseType = elseClause().getType(); // If any of the clauses are nullable, the whole expression is also nullable. if (ifClauses().stream().anyMatch(clause -> clause.then().getType().nullable())) { @@ -744,7 +747,7 @@ public static ImmutableExpression.IfThen.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -779,7 +782,7 @@ public static ImmutableExpression.Cast.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -805,7 +808,7 @@ public static ImmutableExpression.ScalarFunctionInvocation.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -846,7 +849,7 @@ public static ImmutableExpression.WindowFunctionInvocation.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -858,7 +861,7 @@ enum WindowBoundsType { private final io.substrait.proto.Expression.WindowFunction.BoundsType proto; - WindowBoundsType(io.substrait.proto.Expression.WindowFunction.BoundsType proto) { + WindowBoundsType(final io.substrait.proto.Expression.WindowFunction.BoundsType proto) { this.proto = proto; } @@ -867,8 +870,8 @@ public io.substrait.proto.Expression.WindowFunction.BoundsType toProto() { } public static WindowBoundsType fromProto( - io.substrait.proto.Expression.WindowFunction.BoundsType proto) { - for (WindowBoundsType v : values()) { + final io.substrait.proto.Expression.WindowFunction.BoundsType proto) { + for (final WindowBoundsType v : values()) { if (v.proto == proto) { return v; } @@ -895,7 +898,7 @@ public static ImmutableExpression.SingleOrList.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -917,7 +920,7 @@ public static ImmutableExpression.MultiOrList.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -961,7 +964,7 @@ public static ImmutableExpression.SetPredicate.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -976,7 +979,7 @@ public static ImmutableExpression.ScalarSubquery.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -998,7 +1001,7 @@ public static ImmutableExpression.InPredicate.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } } @@ -1013,7 +1016,7 @@ enum PredicateOp { private final io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp proto; - PredicateOp(io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp proto) { + PredicateOp(final io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp proto) { this.proto = proto; } @@ -1022,8 +1025,8 @@ public io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp toProto() } public static PredicateOp fromProto( - io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp proto) { - for (PredicateOp v : values()) { + final io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp proto) { + for (final PredicateOp v : values()) { if (v.proto == proto) { return v; } @@ -1040,7 +1043,7 @@ enum AggregationInvocation { private final io.substrait.proto.AggregateFunction.AggregationInvocation proto; - AggregationInvocation(io.substrait.proto.AggregateFunction.AggregationInvocation proto) { + AggregationInvocation(final io.substrait.proto.AggregateFunction.AggregationInvocation proto) { this.proto = proto; } @@ -1048,8 +1051,9 @@ public io.substrait.proto.AggregateFunction.AggregationInvocation toProto() { return proto; } - public static AggregationInvocation fromProto(AggregateFunction.AggregationInvocation proto) { - for (AggregationInvocation v : values()) { + public static AggregationInvocation fromProto( + final AggregateFunction.AggregationInvocation proto) { + for (final AggregationInvocation v : values()) { if (v.proto == proto) { return v; } @@ -1071,7 +1075,7 @@ enum AggregationPhase { private final io.substrait.proto.AggregationPhase proto; - AggregationPhase(io.substrait.proto.AggregationPhase proto) { + AggregationPhase(final io.substrait.proto.AggregationPhase proto) { this.proto = proto; } @@ -1079,8 +1083,8 @@ public io.substrait.proto.AggregationPhase toProto() { return proto; } - public static AggregationPhase fromProto(io.substrait.proto.AggregationPhase proto) { - for (AggregationPhase v : values()) { + public static AggregationPhase fromProto(final io.substrait.proto.AggregationPhase proto) { + for (final AggregationPhase v : values()) { if (v.proto == proto) { return v; } @@ -1099,7 +1103,7 @@ enum SortDirection { private final io.substrait.proto.SortField.SortDirection proto; - SortDirection(io.substrait.proto.SortField.SortDirection proto) { + SortDirection(final io.substrait.proto.SortField.SortDirection proto) { this.proto = proto; } @@ -1107,8 +1111,8 @@ public io.substrait.proto.SortField.SortDirection toProto() { return proto; } - public static SortDirection fromProto(io.substrait.proto.SortField.SortDirection proto) { - for (SortDirection v : values()) { + public static SortDirection fromProto(final io.substrait.proto.SortField.SortDirection proto) { + for (final SortDirection v : values()) { if (v.proto == proto) { return v; } @@ -1126,7 +1130,7 @@ enum FailureBehavior { private final io.substrait.proto.Expression.Cast.FailureBehavior proto; - FailureBehavior(io.substrait.proto.Expression.Cast.FailureBehavior proto) { + FailureBehavior(final io.substrait.proto.Expression.Cast.FailureBehavior proto) { this.proto = proto; } @@ -1135,8 +1139,8 @@ public io.substrait.proto.Expression.Cast.FailureBehavior toProto() { } public static FailureBehavior fromProto( - io.substrait.proto.Expression.Cast.FailureBehavior proto) { - for (FailureBehavior v : values()) { + final io.substrait.proto.Expression.Cast.FailureBehavior proto) { + for (final FailureBehavior v : values()) { if (v.proto == proto) { return v; } diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index adf157d7b..7ed609d9e 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -21,58 +21,58 @@ public class ExpressionCreator { private ExpressionCreator() {} - public static Expression.NullLiteral typedNull(Type t) { + public static Expression.NullLiteral typedNull(final Type t) { return Expression.NullLiteral.builder().type(t).build(); } - public static Expression.BoolLiteral bool(boolean nullable, boolean value) { + public static Expression.BoolLiteral bool(final boolean nullable, final boolean value) { return Expression.BoolLiteral.builder().nullable(nullable).value(value).build(); } - public static Expression.I8Literal i8(boolean nullable, int value) { + public static Expression.I8Literal i8(final boolean nullable, final int value) { return Expression.I8Literal.builder().nullable(nullable).value(value).build(); } - public static Expression.I16Literal i16(boolean nullable, int value) { + public static Expression.I16Literal i16(final boolean nullable, final int value) { return Expression.I16Literal.builder().nullable(nullable).value(value).build(); } - public static Expression.I32Literal i32(boolean nullable, int value) { + public static Expression.I32Literal i32(final boolean nullable, final int value) { return Expression.I32Literal.builder().nullable(nullable).value(value).build(); } - public static Expression.I64Literal i64(boolean nullable, long value) { + public static Expression.I64Literal i64(final boolean nullable, final long value) { return Expression.I64Literal.builder().nullable(nullable).value(value).build(); } - public static Expression.FP32Literal fp32(boolean nullable, float value) { + public static Expression.FP32Literal fp32(final boolean nullable, final float value) { return Expression.FP32Literal.builder().nullable(nullable).value(value).build(); } - public static Expression.FP64Literal fp64(boolean nullable, double value) { + public static Expression.FP64Literal fp64(final boolean nullable, final double value) { return Expression.FP64Literal.builder().nullable(nullable).value(value).build(); } - public static Expression.StrLiteral string(boolean nullable, String value) { + public static Expression.StrLiteral string(final boolean nullable, final String value) { return Expression.StrLiteral.builder().nullable(nullable).value(value).build(); } - public static Expression.BinaryLiteral binary(boolean nullable, ByteString value) { + public static Expression.BinaryLiteral binary(final boolean nullable, final ByteString value) { return Expression.BinaryLiteral.builder().nullable(nullable).value(value).build(); } - public static Expression.BinaryLiteral binary(boolean nullable, byte[] value) { + public static Expression.BinaryLiteral binary(final boolean nullable, final byte[] value) { return Expression.BinaryLiteral.builder() .nullable(nullable) .value(ByteString.copyFrom(value)) .build(); } - public static Expression.DateLiteral date(boolean nullable, int value) { + public static Expression.DateLiteral date(final boolean nullable, final int value) { return Expression.DateLiteral.builder().nullable(nullable).value(value).build(); } - public static Expression.TimeLiteral time(boolean nullable, long value) { + public static Expression.TimeLiteral time(final boolean nullable, final long value) { return Expression.TimeLiteral.builder().nullable(nullable).value(value).build(); } @@ -80,7 +80,7 @@ public static Expression.TimeLiteral time(boolean nullable, long value) { * @deprecated Timestamp is deprecated in favor of PrecisionTimestamp */ @Deprecated - public static Expression.TimestampLiteral timestamp(boolean nullable, long value) { + public static Expression.TimestampLiteral timestamp(final boolean nullable, final long value) { return Expression.TimestampLiteral.builder().nullable(nullable).value(value).build(); } @@ -88,8 +88,9 @@ public static Expression.TimestampLiteral timestamp(boolean nullable, long value * @deprecated Timestamp is deprecated in favor of PrecisionTimestamp */ @Deprecated - public static Expression.TimestampLiteral timestamp(boolean nullable, LocalDateTime value) { - long epochMicro = + public static Expression.TimestampLiteral timestamp( + final boolean nullable, final LocalDateTime value) { + final long epochMicro = TimeUnit.SECONDS.toMicros(value.toEpochSecond(ZoneOffset.UTC)) + TimeUnit.NANOSECONDS.toMicros(value.toLocalTime().getNano()); return timestamp(nullable, epochMicro); @@ -100,14 +101,14 @@ public static Expression.TimestampLiteral timestamp(boolean nullable, LocalDateT */ @Deprecated public static Expression.TimestampLiteral timestamp( - boolean nullable, - int year, - int month, - int dayOfMonth, - int hour, - int minute, - int second, - int micros) { + final boolean nullable, + final int year, + final int month, + final int dayOfMonth, + final int hour, + final int minute, + final int second, + final int micros) { return timestamp( nullable, LocalDateTime.of(year, month, dayOfMonth, hour, minute, second) @@ -118,7 +119,8 @@ public static Expression.TimestampLiteral timestamp( * @deprecated TimestampTZ is deprecated in favor of PrecisionTimestampTZ */ @Deprecated - public static Expression.TimestampTZLiteral timestampTZ(boolean nullable, long value) { + public static Expression.TimestampTZLiteral timestampTZ( + final boolean nullable, final long value) { return Expression.TimestampTZLiteral.builder().nullable(nullable).value(value).build(); } @@ -126,15 +128,16 @@ public static Expression.TimestampTZLiteral timestampTZ(boolean nullable, long v * @deprecated TimestampTZ is deprecated in favor of PrecisionTimestampTZ */ @Deprecated - public static Expression.TimestampTZLiteral timestampTZ(boolean nullable, Instant value) { - long epochMicro = + public static Expression.TimestampTZLiteral timestampTZ( + final boolean nullable, final Instant value) { + final long epochMicro = TimeUnit.SECONDS.toMicros(value.getEpochSecond()) + TimeUnit.NANOSECONDS.toMicros(value.getNano()); return timestampTZ(nullable, epochMicro); } public static Expression.PrecisionTimestampLiteral precisionTimestamp( - boolean nullable, long value, int precision) { + final boolean nullable, final long value, final int precision) { return Expression.PrecisionTimestampLiteral.builder() .nullable(nullable) .value(value) @@ -143,7 +146,7 @@ public static Expression.PrecisionTimestampLiteral precisionTimestamp( } public static Expression.PrecisionTimestampTZLiteral precisionTimestampTZ( - boolean nullable, long value, int precision) { + final boolean nullable, final long value, final int precision) { return Expression.PrecisionTimestampTZLiteral.builder() .nullable(nullable) .value(value) @@ -152,7 +155,7 @@ public static Expression.PrecisionTimestampTZLiteral precisionTimestampTZ( } public static Expression.IntervalYearLiteral intervalYear( - boolean nullable, int years, int months) { + final boolean nullable, final int years, final int months) { return Expression.IntervalYearLiteral.builder() .nullable(nullable) .years(years) @@ -160,12 +163,17 @@ public static Expression.IntervalYearLiteral intervalYear( .build(); } - public static Expression.IntervalDayLiteral intervalDay(boolean nullable, int days, int seconds) { + public static Expression.IntervalDayLiteral intervalDay( + final boolean nullable, final int days, final int seconds) { return intervalDay(nullable, days, seconds, 0, 0); } public static Expression.IntervalDayLiteral intervalDay( - boolean nullable, int days, int seconds, long subseconds, int precision) { + final boolean nullable, + final int days, + final int seconds, + final long subseconds, + final int precision) { return Expression.IntervalDayLiteral.builder() .nullable(nullable) .days(days) @@ -176,13 +184,13 @@ public static Expression.IntervalDayLiteral intervalDay( } public static Expression.IntervalCompoundLiteral intervalCompound( - boolean nullable, - int years, - int months, - int days, - int seconds, - long subseconds, - int precision) { + final boolean nullable, + final int years, + final int months, + final int days, + final int seconds, + final long subseconds, + final int precision) { return Expression.IntervalCompoundLiteral.builder() .nullable(nullable) .years(years) @@ -194,31 +202,34 @@ public static Expression.IntervalCompoundLiteral intervalCompound( .build(); } - public static Expression.UUIDLiteral uuid(boolean nullable, ByteString uuid) { - ByteBuffer bb = uuid.asReadOnlyByteBuffer(); + public static Expression.UUIDLiteral uuid(final boolean nullable, final ByteString uuid) { + final ByteBuffer bb = uuid.asReadOnlyByteBuffer(); return Expression.UUIDLiteral.builder() .nullable(nullable) .value(new UUID(bb.getLong(), bb.getLong())) .build(); } - public static Expression.UUIDLiteral uuid(boolean nullable, UUID uuid) { + public static Expression.UUIDLiteral uuid(final boolean nullable, final UUID uuid) { return Expression.UUIDLiteral.builder().nullable(nullable).value(uuid).build(); } - public static Expression.FixedCharLiteral fixedChar(boolean nullable, String str) { + public static Expression.FixedCharLiteral fixedChar(final boolean nullable, final String str) { return Expression.FixedCharLiteral.builder().nullable(nullable).value(str).build(); } - public static Expression.VarCharLiteral varChar(boolean nullable, String str, int len) { + public static Expression.VarCharLiteral varChar( + final boolean nullable, final String str, final int len) { return Expression.VarCharLiteral.builder().nullable(nullable).value(str).length(len).build(); } - public static Expression.FixedBinaryLiteral fixedBinary(boolean nullable, ByteString bytes) { + public static Expression.FixedBinaryLiteral fixedBinary( + final boolean nullable, final ByteString bytes) { return Expression.FixedBinaryLiteral.builder().nullable(nullable).value(bytes).build(); } - public static Expression.FixedBinaryLiteral fixedBinary(boolean nullable, byte[] bytes) { + public static Expression.FixedBinaryLiteral fixedBinary( + final boolean nullable, final byte[] bytes) { return Expression.FixedBinaryLiteral.builder() .nullable(nullable) .value(ByteString.copyFrom(bytes)) @@ -226,7 +237,7 @@ public static Expression.FixedBinaryLiteral fixedBinary(boolean nullable, byte[] } public static Expression.DecimalLiteral decimal( - boolean nullable, ByteString value, int precision, int scale) { + final boolean nullable, final ByteString value, final int precision, final int scale) { return Expression.DecimalLiteral.builder() .nullable(nullable) .value(value) @@ -236,8 +247,8 @@ public static Expression.DecimalLiteral decimal( } public static Expression.DecimalLiteral decimal( - boolean nullable, BigDecimal value, int precision, int scale) { - byte[] twosComplement = DecimalUtil.encodeDecimalIntoBytes(value, scale, 16); + final boolean nullable, final BigDecimal value, final int precision, final int scale) { + final byte[] twosComplement = DecimalUtil.encodeDecimalIntoBytes(value, scale, 16); return Expression.DecimalLiteral.builder() .nullable(nullable) @@ -248,12 +259,12 @@ public static Expression.DecimalLiteral decimal( } public static Expression.MapLiteral map( - boolean nullable, Map values) { + final boolean nullable, final Map values) { return Expression.MapLiteral.builder().nullable(nullable).putAllValues(values).build(); } public static Expression.EmptyMapLiteral emptyMap( - boolean nullable, Type keyType, Type valueType) { + final boolean nullable, final Type keyType, final Type valueType) { return Expression.EmptyMapLiteral.builder() .keyType(keyType) .valueType(valueType) @@ -261,33 +272,36 @@ public static Expression.EmptyMapLiteral emptyMap( .build(); } - public static Expression.ListLiteral list(boolean nullable, Expression.Literal... values) { + public static Expression.ListLiteral list( + final boolean nullable, final Expression.Literal... values) { return Expression.ListLiteral.builder().nullable(nullable).addValues(values).build(); } public static Expression.ListLiteral list( - boolean nullable, Iterable values) { + final boolean nullable, final Iterable values) { return Expression.ListLiteral.builder().nullable(nullable).addAllValues(values).build(); } - public static Expression.EmptyListLiteral emptyList(boolean listNullable, Type elementType) { + public static Expression.EmptyListLiteral emptyList( + final boolean listNullable, final Type elementType) { return Expression.EmptyListLiteral.builder() .elementType(elementType) .nullable(listNullable) .build(); } - public static Expression.StructLiteral struct(boolean nullable, Expression.Literal... values) { + public static Expression.StructLiteral struct( + final boolean nullable, final Expression.Literal... values) { return Expression.StructLiteral.builder().nullable(nullable).addFields(values).build(); } public static Expression.StructLiteral struct( - boolean nullable, Iterable values) { + final boolean nullable, final Iterable values) { return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build(); } public static Expression.UserDefinedLiteral userDefinedLiteral( - boolean nullable, String urn, String name, Any value) { + final boolean nullable, final String urn, final String name, final Any value) { return Expression.UserDefinedLiteral.builder() .nullable(nullable) .urn(urn) @@ -297,7 +311,9 @@ public static Expression.UserDefinedLiteral userDefinedLiteral( } public static Expression.Switch switchStatement( - Expression match, Expression defaultExpression, Expression.SwitchClause... conditionClauses) { + final Expression match, + final Expression defaultExpression, + final Expression.SwitchClause... conditionClauses) { return Expression.Switch.builder() .match(match) .defaultClause(defaultExpression) @@ -306,9 +322,9 @@ public static Expression.Switch switchStatement( } public static Expression.Switch switchStatement( - Expression match, - Expression defaultExpression, - Iterable conditionClauses) { + final Expression match, + final Expression defaultExpression, + final Iterable conditionClauses) { return Expression.Switch.builder() .match(match) .defaultClause(defaultExpression) @@ -317,7 +333,7 @@ public static Expression.Switch switchStatement( } public static Expression.SwitchClause switchClause( - Expression.Literal expectedValue, Expression resultExpression) { + final Expression.Literal expectedValue, final Expression resultExpression) { return Expression.SwitchClause.builder() .condition(expectedValue) .then(resultExpression) @@ -325,7 +341,7 @@ public static Expression.SwitchClause switchClause( } public static Expression.IfThen ifThenStatement( - Expression elseExpression, Expression.IfClause... conditionClauses) { + final Expression elseExpression, final Expression.IfClause... conditionClauses) { return Expression.IfThen.builder() .elseClause(elseExpression) .addIfClauses(conditionClauses) @@ -333,7 +349,8 @@ public static Expression.IfThen ifThenStatement( } public static Expression.IfThen ifThenStatement( - Expression elseExpression, Iterable conditionClauses) { + final Expression elseExpression, + final Iterable conditionClauses) { return Expression.IfThen.builder() .elseClause(elseExpression) .addAllIfClauses(conditionClauses) @@ -341,7 +358,7 @@ public static Expression.IfThen ifThenStatement( } public static Expression.IfClause ifThenClause( - Expression conditionExpression, Expression resultExpression) { + final Expression conditionExpression, final Expression resultExpression) { return Expression.IfClause.builder() .condition(conditionExpression) .then(resultExpression) @@ -349,9 +366,9 @@ public static Expression.IfClause ifThenClause( } public static Expression.ScalarFunctionInvocation scalarFunction( - SimpleExtension.ScalarFunctionVariant declaration, - Type outputType, - FunctionArg... arguments) { + final SimpleExtension.ScalarFunctionVariant declaration, + final Type outputType, + final FunctionArg... arguments) { return scalarFunction(declaration, outputType, Arrays.asList(arguments)); } @@ -360,9 +377,9 @@ public static Expression.ScalarFunctionInvocation scalarFunction( * e.g. options */ public static Expression.ScalarFunctionInvocation scalarFunction( - SimpleExtension.ScalarFunctionVariant declaration, - Type outputType, - Iterable arguments) { + final SimpleExtension.ScalarFunctionVariant declaration, + final Type outputType, + final Iterable arguments) { return Expression.ScalarFunctionInvocation.builder() .declaration(declaration) .outputType(outputType) @@ -375,12 +392,12 @@ public static Expression.ScalarFunctionInvocation scalarFunction( * options */ public static AggregateFunctionInvocation aggregateFunction( - SimpleExtension.AggregateFunctionVariant declaration, - Type outputType, - Expression.AggregationPhase phase, - List sort, - Expression.AggregationInvocation invocation, - Iterable arguments) { + final SimpleExtension.AggregateFunctionVariant declaration, + final Type outputType, + final Expression.AggregationPhase phase, + final List sort, + final Expression.AggregationInvocation invocation, + final Iterable arguments) { return AggregateFunctionInvocation.builder() .declaration(declaration) .outputType(outputType) @@ -392,12 +409,12 @@ public static AggregateFunctionInvocation aggregateFunction( } public static AggregateFunctionInvocation aggregateFunction( - SimpleExtension.AggregateFunctionVariant declaration, - Type outputType, - Expression.AggregationPhase phase, - List sort, - Expression.AggregationInvocation invocation, - FunctionArg... arguments) { + final SimpleExtension.AggregateFunctionVariant declaration, + final Type outputType, + final Expression.AggregationPhase phase, + final List sort, + final Expression.AggregationInvocation invocation, + final FunctionArg... arguments) { return aggregateFunction( declaration, outputType, phase, sort, invocation, Arrays.asList(arguments)); } @@ -407,16 +424,16 @@ public static AggregateFunctionInvocation aggregateFunction( * e.g. options */ public static Expression.WindowFunctionInvocation windowFunction( - SimpleExtension.WindowFunctionVariant declaration, - Type outputType, - Expression.AggregationPhase phase, - List sort, - Expression.AggregationInvocation invocation, - List partitionBy, - Expression.WindowBoundsType boundsType, - WindowBound lowerBound, - WindowBound upperBound, - Iterable arguments) { + final SimpleExtension.WindowFunctionVariant declaration, + final Type outputType, + final Expression.AggregationPhase phase, + final List sort, + final Expression.AggregationInvocation invocation, + final List partitionBy, + final Expression.WindowBoundsType boundsType, + final WindowBound lowerBound, + final WindowBound upperBound, + final Iterable arguments) { return Expression.WindowFunctionInvocation.builder() .declaration(declaration) .outputType(outputType) @@ -436,14 +453,14 @@ public static Expression.WindowFunctionInvocation windowFunction( * other parameters, e.g. options */ public static ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunction( - SimpleExtension.WindowFunctionVariant declaration, - Type outputType, - Expression.AggregationPhase phase, - Expression.AggregationInvocation invocation, - Expression.WindowBoundsType boundsType, - WindowBound lowerBound, - WindowBound upperBound, - Iterable arguments) { + final SimpleExtension.WindowFunctionVariant declaration, + final Type outputType, + final Expression.AggregationPhase phase, + final Expression.AggregationInvocation invocation, + final Expression.WindowBoundsType boundsType, + final WindowBound lowerBound, + final WindowBound upperBound, + final Iterable arguments) { return ConsistentPartitionWindow.WindowRelFunctionInvocation.builder() .declaration(declaration) .outputType(outputType) @@ -457,16 +474,16 @@ public static ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFun } public static Expression.WindowFunctionInvocation windowFunction( - SimpleExtension.WindowFunctionVariant declaration, - Type outputType, - Expression.AggregationPhase phase, - List sort, - Expression.AggregationInvocation invocation, - List partitionBy, - Expression.WindowBoundsType boundsType, - WindowBound lowerBound, - WindowBound upperBound, - FunctionArg... arguments) { + final SimpleExtension.WindowFunctionVariant declaration, + final Type outputType, + final Expression.AggregationPhase phase, + final List sort, + final Expression.AggregationInvocation invocation, + final List partitionBy, + final Expression.WindowBoundsType boundsType, + final WindowBound lowerBound, + final WindowBound upperBound, + final FunctionArg... arguments) { return windowFunction( declaration, outputType, @@ -481,7 +498,9 @@ public static Expression.WindowFunctionInvocation windowFunction( } public static Expression cast( - Type type, Expression expression, Expression.FailureBehavior failureBehavior) { + final Type type, + final Expression expression, + final Expression.FailureBehavior failureBehavior) { return Expression.Cast.builder() .type(type) .input(expression) diff --git a/core/src/main/java/io/substrait/expression/FieldReference.java b/core/src/main/java/io/substrait/expression/FieldReference.java index 7a4fdce5d..081a03f15 100644 --- a/core/src/main/java/io/substrait/expression/FieldReference.java +++ b/core/src/main/java/io/substrait/expression/FieldReference.java @@ -31,7 +31,7 @@ public static ImmutableFieldReference.Builder builder() { @Override public R accept( - ExpressionVisitor visitor, C context) throws E { + final ExpressionVisitor visitor, final C context) throws E { return visitor.visit(this, context); } @@ -45,12 +45,12 @@ public boolean isOuterReference() { return outerReferenceStepsOut().orElse(0) > 0; } - public FieldReference dereferenceStruct(int index) { - Type newType = StructFieldFinder.getReferencedType(type(), index); + public FieldReference dereferenceStruct(final int index) { + final Type newType = StructFieldFinder.getReferencedType(type(), index); return dereference(newType, StructField.of(index)); } - private FieldReference dereference(Type newType, ReferenceSegment nextSegment) { + private FieldReference dereference(final Type newType, final ReferenceSegment nextSegment) { return ImmutableFieldReference.builder() .type(newType) .addSegments(nextSegment) @@ -59,17 +59,17 @@ private FieldReference dereference(Type newType, ReferenceSegment nextSegment) { .build(); } - public FieldReference dereferenceList(int index) { - Type newType = ListIndexFinder.getReferencedType(type(), index); + public FieldReference dereferenceList(final int index) { + final Type newType = ListIndexFinder.getReferencedType(type(), index); return dereference(newType, ListElement.of(index)); } - public FieldReference dereferenceMap(Literal mapKey) { - Type newType = MapKeyFinder.getReferencedType(type(), mapKey.getType()); + public FieldReference dereferenceMap(final Literal mapKey) { + final Type newType = MapKeyFinder.getReferencedType(type(), mapKey.getType()); return dereference(newType, MapKey.of(mapKey)); } - public static FieldReference newMapReference(Literal mapKey, Expression expression) { + public static FieldReference newMapReference(final Literal mapKey, final Expression expression) { return ImmutableFieldReference.builder() .addSegments(MapKey.of(mapKey)) .inputExpression(expression) @@ -77,7 +77,7 @@ public static FieldReference newMapReference(Literal mapKey, Expression expressi .build(); } - public static FieldReference newListReference(int index, Expression expression) { + public static FieldReference newListReference(final int index, final Expression expression) { return ImmutableFieldReference.builder() .addSegments(ListElement.of(index)) .inputExpression(expression) @@ -85,7 +85,7 @@ public static FieldReference newListReference(int index, Expression expression) .build(); } - public static FieldReference newStructReference(int index, Expression expression) { + public static FieldReference newStructReference(final int index, final Expression expression) { return ImmutableFieldReference.builder() .addSegments(StructField.of(index)) .inputExpression(expression) @@ -93,7 +93,7 @@ public static FieldReference newStructReference(int index, Expression expression .build(); } - public static FieldReference newRootStructReference(int index, Type knownType) { + public static FieldReference newRootStructReference(final int index, final Type knownType) { return ImmutableFieldReference.builder() .addSegments(StructField.of(index)) .type(knownType) @@ -101,7 +101,7 @@ public static FieldReference newRootStructReference(int index, Type knownType) { } public static FieldReference newRootStructOuterReference( - int index, Type knownType, int stepsOut) { + final int index, final Type knownType, final int stepsOut) { return ImmutableFieldReference.builder() .addSegments(StructField.of(index)) .type(knownType) @@ -109,16 +109,16 @@ public static FieldReference newRootStructOuterReference( .build(); } - public static FieldReference newInputRelReference(int index, Rel rel) { + public static FieldReference newInputRelReference(final int index, final Rel rel) { return newInputRelReference(index, Collections.singletonList(rel)); } - public static FieldReference newInputRelReference(int index, List rels) { + public static FieldReference newInputRelReference(final int index, final List rels) { int currentOffset = 0; - for (Rel r : rels) { - int relSize = r.getRecordType().fields().size(); + for (final Rel r : rels) { + final int relSize = r.getRecordType().fields().size(); if (index < currentOffset + relSize) { - Type referenceType = r.getRecordType().fields().get(index - currentOffset); + final Type referenceType = r.getRecordType().fields().get(index - currentOffset); return ImmutableFieldReference.builder() .addSegments(StructField.of(index)) .type(referenceType) @@ -146,22 +146,22 @@ public interface ReferenceSegment { public abstract static class StructField implements ReferenceSegment { public abstract int offset(); - public static StructField of(int index) { + public static StructField of(final int index) { return ImmutableStructField.builder().offset(index).build(); } @Override - public FieldReference apply(FieldReference reference) { + public FieldReference apply(final FieldReference reference) { return reference.dereferenceStruct(offset()); } @Override - public FieldReference constructOnExpression(Expression expr) { + public FieldReference constructOnExpression(final Expression expr) { return FieldReference.newStructReference(offset(), expr); } @Override - public FieldReference constructOnRoot(Type.Struct struct) { + public FieldReference constructOnRoot(final Type.Struct struct) { if (offset() >= struct.fields().size()) { throw new IllegalArgumentException( String.format( @@ -176,22 +176,22 @@ public FieldReference constructOnRoot(Type.Struct struct) { public abstract static class ListElement implements ReferenceSegment { public abstract int offset(); - public static ListElement of(int index) { + public static ListElement of(final int index) { return ImmutableListElement.builder().offset(index).build(); } @Override - public FieldReference apply(FieldReference reference) { + public FieldReference apply(final FieldReference reference) { return reference.dereferenceList(offset()); } @Override - public FieldReference constructOnExpression(Expression expr) { + public FieldReference constructOnExpression(final Expression expr) { return FieldReference.newListReference(offset(), expr); } @Override - public FieldReference constructOnRoot(Type.Struct struct) { + public FieldReference constructOnRoot(final Type.Struct struct) { throw new UnsupportedOperationException(); } } @@ -200,38 +200,40 @@ public FieldReference constructOnRoot(Type.Struct struct) { public abstract static class MapKey implements ReferenceSegment { public abstract Expression.Literal key(); - public static MapKey of(Expression.Literal key) { + public static MapKey of(final Expression.Literal key) { return ImmutableMapKey.builder().key(key).build(); } @Override - public FieldReference apply(FieldReference reference) { + public FieldReference apply(final FieldReference reference) { return reference.dereferenceMap(key()); } @Override - public FieldReference constructOnExpression(Expression expr) { + public FieldReference constructOnExpression(final Expression expr) { return FieldReference.newMapReference(key(), expr); } @Override - public FieldReference constructOnRoot(Type.Struct struct) { + public FieldReference constructOnRoot(final Type.Struct struct) { throw new UnsupportedOperationException(); } } public static FieldReference ofExpression( - Expression expression, List segments) { + final Expression expression, final List segments) { return of(null, expression, segments); } private static FieldReference of( - Type.Struct struct, Expression expression, List segments) { + final Type.Struct struct, + final Expression expression, + final List segments) { FieldReference reference = null; Collections.reverse(segments); for (int i = 0; i < segments.size(); i++) { if (i == 0) { - ReferenceSegment last = segments.get(0); + final ReferenceSegment last = segments.get(0); reference = struct == null ? last.constructOnExpression(expression) : last.constructOnRoot(struct); } else { @@ -242,7 +244,8 @@ private static FieldReference of( return reference; } - public static FieldReference ofRoot(Type.Struct struct, List segments) { + public static FieldReference ofRoot( + final Type.Struct struct, final List segments) { return of(struct, null, segments); } @@ -251,21 +254,21 @@ private static class StructFieldFinder private final int index; - private StructFieldFinder(int index) { + private StructFieldFinder(final int index) { super( "This visitor only supports retrieving struct types. Was applied to a non-struct type."); this.index = index; } @Override - public Type visit(Type.Struct expr) throws RuntimeException { + public Type visit(final Type.Struct expr) throws RuntimeException { if (expr.fields().size() < index) { throw new IllegalArgumentException("Undefined struct type."); } return expr.fields().get(index); } - public static Type getReferencedType(Type type, int index) { + public static Type getReferencedType(final Type type, final int index) { return type.accept(new StructFieldFinder(index)); } } @@ -278,11 +281,11 @@ private ListIndexFinder() { } @Override - public Type visit(Type.ListType expr) throws RuntimeException { + public Type visit(final Type.ListType expr) throws RuntimeException { return expr.elementType(); } - public static Type getReferencedType(Type type, int index) { + public static Type getReferencedType(final Type type, final int index) { return type.accept(new ListIndexFinder()); } } @@ -291,14 +294,14 @@ private static class MapKeyFinder extends TypeVisitor.TypeThrowsVisitor { } static FuncArgVisitor toProto( - TypeExpressionVisitor typeVisitor, - ExpressionVisitor + final TypeExpressionVisitor typeVisitor, + final ExpressionVisitor< + io.substrait.proto.Expression, EmptyVisitationContext, RuntimeException> expressionVisitor) { return new FuncArgVisitor() { @Override public FunctionArgument visitExpr( - SimpleExtension.Function fnDef, int argIdx, Expression e, EmptyVisitationContext context) + final SimpleExtension.Function fnDef, + final int argIdx, + final Expression e, + final EmptyVisitationContext context) throws RuntimeException { - io.substrait.proto.Expression pE = e.accept(expressionVisitor, context); + final io.substrait.proto.Expression pE = e.accept(expressionVisitor, context); return FunctionArgument.newBuilder().setValue(pE).build(); } @Override public FunctionArgument visitType( - SimpleExtension.Function fnDef, int argIdx, Type t, EmptyVisitationContext context) + final SimpleExtension.Function fnDef, + final int argIdx, + final Type t, + final EmptyVisitationContext context) throws RuntimeException { - io.substrait.proto.Type pTyp = t.accept(typeVisitor); + final io.substrait.proto.Type pTyp = t.accept(typeVisitor); return FunctionArgument.newBuilder().setType(pTyp).build(); } @Override public FunctionArgument visitEnumArg( - SimpleExtension.Function fnDef, int argIdx, EnumArg ea, EmptyVisitationContext context) + final SimpleExtension.Function fnDef, + final int argIdx, + final EnumArg ea, + final EmptyVisitationContext context) throws RuntimeException { FunctionArgument.Builder enumBldr = FunctionArgument.newBuilder(); @@ -75,13 +85,14 @@ class ProtoFrom { private final ProtoTypeConverter protoTypeConverter; public ProtoFrom( - ProtoExpressionConverter protoExprConverter, ProtoTypeConverter protoTypeConverter) { + final ProtoExpressionConverter protoExprConverter, + final ProtoTypeConverter protoTypeConverter) { this.protoExprConverter = protoExprConverter; this.protoTypeConverter = protoTypeConverter; } public FunctionArg convert( - SimpleExtension.Function funcDef, int argIdx, FunctionArgument fArg) { + final SimpleExtension.Function funcDef, final int argIdx, final FunctionArgument fArg) { switch (fArg.getArgTypeCase()) { case TYPE: return protoTypeConverter.from(fArg.getType()); @@ -89,9 +100,9 @@ public FunctionArg convert( return protoExprConverter.from(fArg.getValue()); case ENUM: { - SimpleExtension.EnumArgument enumArgDef = + final SimpleExtension.EnumArgument enumArgDef = (SimpleExtension.EnumArgument) funcDef.args().get(argIdx); - String optionValue = fArg.getEnum(); + final String optionValue = fArg.getEnum(); return EnumArg.of(enumArgDef, optionValue); } default: diff --git a/core/src/main/java/io/substrait/expression/WindowBound.java b/core/src/main/java/io/substrait/expression/WindowBound.java index d1403f881..d51d37d97 100644 --- a/core/src/main/java/io/substrait/expression/WindowBound.java +++ b/core/src/main/java/io/substrait/expression/WindowBound.java @@ -24,12 +24,12 @@ interface WindowBoundVisitor { abstract class Preceding implements WindowBound { public abstract long offset(); - public static Preceding of(long offset) { + public static Preceding of(final long offset) { return ImmutableWindowBound.Preceding.builder().offset(offset).build(); } @Override - public R accept(WindowBoundVisitor visitor) { + public R accept(final WindowBoundVisitor visitor) { return visitor.visit(this); } } @@ -38,12 +38,12 @@ public R accept(WindowBoundVisitor visitor) { abstract class Following implements WindowBound { public abstract long offset(); - public static Following of(long offset) { + public static Following of(final long offset) { return ImmutableWindowBound.Following.builder().offset(offset).build(); } @Override - public R accept(WindowBoundVisitor visitor) { + public R accept(final WindowBoundVisitor visitor) { return visitor.visit(this); } } @@ -51,7 +51,7 @@ public R accept(WindowBoundVisitor visitor) { @Value.Immutable abstract class CurrentRow implements WindowBound { @Override - public R accept(WindowBoundVisitor visitor) { + public R accept(final WindowBoundVisitor visitor) { return visitor.visit(this); } } @@ -59,7 +59,7 @@ public R accept(WindowBoundVisitor visitor) { @Value.Immutable abstract class Unbounded implements WindowBound { @Override - public R accept(WindowBoundVisitor visitor) { + public R accept(final WindowBoundVisitor visitor) { return visitor.visit(this); } } diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index caf145dfc..9703257e4 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -32,7 +32,7 @@ public class ExpressionProtoConverter protected final ExtensionCollector extensionCollector; public ExpressionProtoConverter( - ExtensionCollector extensionCollector, RelProtoConverter relProtoConverter) { + final ExtensionCollector extensionCollector, final RelProtoConverter relProtoConverter) { this.extensionCollector = extensionCollector; this.relProtoConverter = relProtoConverter; this.typeProtoConverter = new TypeProtoConverter(extensionCollector); @@ -46,117 +46,132 @@ public TypeProtoConverter getTypeProtoConverter() { return this.typeProtoConverter; } - public io.substrait.proto.Expression toProto(io.substrait.expression.Expression expression) { + public io.substrait.proto.Expression toProto( + final io.substrait.expression.Expression expression) { return expression.accept(this, EmptyVisitationContext.INSTANCE); } public List toProto( - List expressions) { + final List expressions) { return expressions.stream().map(this::toProto).collect(Collectors.toList()); } - protected io.substrait.proto.Rel toProto(io.substrait.relation.Rel rel) { + protected io.substrait.proto.Rel toProto(final io.substrait.relation.Rel rel) { return relProtoConverter.toProto(rel); } - protected io.substrait.proto.Type toProto(io.substrait.type.Type type) { + protected io.substrait.proto.Type toProto(final io.substrait.type.Type type) { return typeProtoConverter.toProto(type); } @Override public Expression visit( - io.substrait.expression.Expression.NullLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.NullLiteral expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNull(toProto(expr.type()))); } - private Expression lit(Consumer consumer) { - Expression.Literal.Builder builder = Expression.Literal.newBuilder(); + private Expression lit(final Consumer consumer) { + final Expression.Literal.Builder builder = Expression.Literal.newBuilder(); consumer.accept(builder); return Expression.newBuilder().setLiteral(builder).build(); } @Override public Expression visit( - io.substrait.expression.Expression.BoolLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.BoolLiteral expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setBoolean(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.I8Literal expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.I8Literal expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setI8(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.I16Literal expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.I16Literal expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setI16(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.I32Literal expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.I32Literal expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setI32(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.I64Literal expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.I64Literal expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setI64(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.FP32Literal expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.FP32Literal expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setFp32(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.FP64Literal expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.FP64Literal expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setFp64(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.StrLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.StrLiteral expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setString(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.BinaryLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.BinaryLiteral expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setBinary(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.TimeLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.TimeLiteral expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setTime(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.DateLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.DateLiteral expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setDate(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.TimestampLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.TimestampLiteral expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setTimestamp(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.TimestampTZLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.TimestampTZLiteral expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setTimestampTz(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.PrecisionTimestampLiteral expr, - EmptyVisitationContext context) { + final io.substrait.expression.Expression.PrecisionTimestampLiteral expr, + final EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -170,8 +185,8 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.PrecisionTimestampTZLiteral expr, - EmptyVisitationContext context) { + final io.substrait.expression.Expression.PrecisionTimestampTZLiteral expr, + final EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -185,7 +200,8 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.IntervalYearLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.IntervalYearLiteral expr, + final EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -197,7 +213,8 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.IntervalDayLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.IntervalDayLiteral expr, + final EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -211,8 +228,8 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.IntervalCompoundLiteral expr, - EmptyVisitationContext context) { + final io.substrait.expression.Expression.IntervalCompoundLiteral expr, + final EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -232,19 +249,22 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.UUIDLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.UUIDLiteral expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setUuid(expr.toBytes())); } @Override public Expression visit( - io.substrait.expression.Expression.FixedCharLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.FixedCharLiteral expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setFixedChar(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.VarCharLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.VarCharLiteral expr, + final EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -256,13 +276,15 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.FixedBinaryLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.FixedBinaryLiteral expr, + final EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setFixedBinary(expr.value())); } @Override public Expression visit( - io.substrait.expression.Expression.DecimalLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.DecimalLiteral expr, + final EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -275,15 +297,16 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.MapLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.MapLiteral expr, + final EmptyVisitationContext context) { return lit( bldr -> { - List keyValues = + final List keyValues = expr.values().entrySet().stream() .map( e -> { - Expression.Literal key = toLiteral(e.getKey()); - Expression.Literal value = toLiteral(e.getValue()); + final Expression.Literal key = toLiteral(e.getKey()); + final Expression.Literal value = toLiteral(e.getValue()); return Expression.Literal.Map.KeyValue.newBuilder() .setKey(key) .setValue(value) @@ -297,10 +320,11 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.EmptyMapLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.EmptyMapLiteral expr, + final EmptyVisitationContext context) { return lit( bldr -> { - Type protoMapType = toProto(expr.getType()); + final Type protoMapType = toProto(expr.getType()); bldr.setEmptyMap(protoMapType.getMap()) // For empty maps, the Literal message's own nullable field should be ignored // in favor of the nullability of the Type.Map in the literal's @@ -313,10 +337,11 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.ListLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.ListLiteral expr, + final EmptyVisitationContext context) { return lit( bldr -> { - List values = + final List values = expr.values().stream() .map(this::toLiteral) .collect(java.util.stream.Collectors.toList()); @@ -327,11 +352,12 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.EmptyListLiteral expr, EmptyVisitationContext context) + final io.substrait.expression.Expression.EmptyListLiteral expr, + final EmptyVisitationContext context) throws RuntimeException { return lit( builder -> { - Type protoListType = toProto(expr.getType()); + final Type protoListType = toProto(expr.getType()); builder .setEmptyList(protoListType.getList()) // For empty lists, the Literal message's own nullable field should be ignored @@ -345,10 +371,11 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.StructLiteral expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.StructLiteral expr, + final EmptyVisitationContext context) { return lit( bldr -> { - List values = + final List values = expr.fields().stream() .map(this::toLiteral) .collect(java.util.stream.Collectors.toList()); @@ -359,8 +386,9 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) { - int typeReference = + final io.substrait.expression.Expression.UserDefinedLiteral expr, + final EmptyVisitationContext context) { + final int typeReference = extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); return lit( bldr -> { @@ -371,22 +399,22 @@ public Expression visit( .setTypeReference(typeReference) .setValue(Any.parseFrom(expr.value()))) .build(); - } catch (InvalidProtocolBufferException e) { + } catch (final InvalidProtocolBufferException e) { throw new IllegalStateException(e); } }); } - private Expression.Literal toLiteral(io.substrait.expression.Expression expression) { - Expression e = toProto(expression); + private Expression.Literal toLiteral(final io.substrait.expression.Expression expression) { + final Expression e = toProto(expression); assert e.getRexTypeCase() == Expression.RexTypeCase.LITERAL; return e.getLiteral(); } @Override public Expression visit( - io.substrait.expression.Expression.Switch expr, EmptyVisitationContext context) { - List clauses = + final io.substrait.expression.Expression.Switch expr, final EmptyVisitationContext context) { + final List clauses = expr.switchClauses().stream() .map( s -> @@ -406,8 +434,8 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.IfThen expr, EmptyVisitationContext context) { - List clauses = + final io.substrait.expression.Expression.IfThen expr, final EmptyVisitationContext context) { + final List clauses = expr.ifClauses().stream() .map( s -> @@ -424,10 +452,10 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.ScalarFunctionInvocation expr, - EmptyVisitationContext context) { + final io.substrait.expression.Expression.ScalarFunctionInvocation expr, + final EmptyVisitationContext context) { - FunctionArg.FuncArgVisitor + final FunctionArg.FuncArgVisitor argVisitor = FunctionArg.toProto(typeProtoConverter, this); return Expression.newBuilder() @@ -446,7 +474,7 @@ public Expression visit( .build(); } - public static FunctionOption from(io.substrait.expression.FunctionOption option) { + public static FunctionOption from(final io.substrait.expression.FunctionOption option) { return FunctionOption.newBuilder() .setName(option.getName()) .addAllPreference(option.values()) @@ -455,7 +483,7 @@ public static FunctionOption from(io.substrait.expression.FunctionOption option) @Override public Expression visit( - io.substrait.expression.Expression.Cast expr, EmptyVisitationContext context) { + final io.substrait.expression.Expression.Cast expr, final EmptyVisitationContext context) { return Expression.newBuilder() .setCast( Expression.Cast.newBuilder() @@ -467,7 +495,8 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.SingleOrList expr, EmptyVisitationContext context) + final io.substrait.expression.Expression.SingleOrList expr, + final EmptyVisitationContext context) throws RuntimeException { return Expression.newBuilder() .setSingularOrList( @@ -479,7 +508,8 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.MultiOrList expr, EmptyVisitationContext context) + final io.substrait.expression.Expression.MultiOrList expr, + final EmptyVisitationContext context) throws RuntimeException { return Expression.newBuilder() .setMultiOrList( @@ -497,30 +527,30 @@ public Expression visit( } @Override - public Expression visit(FieldReference expr, EmptyVisitationContext context) { + public Expression visit(final FieldReference expr, final EmptyVisitationContext context) { Expression.ReferenceSegment seg = null; - for (FieldReference.ReferenceSegment segment : expr.segments()) { - Expression.ReferenceSegment.Builder protoSegment; + for (final FieldReference.ReferenceSegment segment : expr.segments()) { + final Expression.ReferenceSegment.Builder protoSegment; if (segment instanceof FieldReference.StructField) { - FieldReference.StructField f = (FieldReference.StructField) segment; - Expression.ReferenceSegment.StructField.Builder bldr = + final FieldReference.StructField f = (FieldReference.StructField) segment; + final Expression.ReferenceSegment.StructField.Builder bldr = Expression.ReferenceSegment.StructField.newBuilder().setField(f.offset()); if (seg != null) { bldr.setChild(seg); } protoSegment = Expression.ReferenceSegment.newBuilder().setStructField(bldr); } else if (segment instanceof FieldReference.ListElement) { - FieldReference.ListElement f = (FieldReference.ListElement) segment; - Expression.ReferenceSegment.ListElement.Builder bldr = + final FieldReference.ListElement f = (FieldReference.ListElement) segment; + final Expression.ReferenceSegment.ListElement.Builder bldr = Expression.ReferenceSegment.ListElement.newBuilder().setOffset(f.offset()); if (seg != null) { bldr.setChild(seg); } protoSegment = Expression.ReferenceSegment.newBuilder().setListElement(bldr); } else if (segment instanceof FieldReference.MapKey) { - FieldReference.MapKey f = (FieldReference.MapKey) segment; - Expression.ReferenceSegment.MapKey.Builder bldr = + final FieldReference.MapKey f = (FieldReference.MapKey) segment; + final Expression.ReferenceSegment.MapKey.Builder bldr = Expression.ReferenceSegment.MapKey.newBuilder().setMapKey(toLiteral(f.key())); if (seg != null) { bldr.setChild(seg); @@ -532,7 +562,7 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) { seg = protoSegment.build(); } - Expression.FieldReference.Builder out = + final Expression.FieldReference.Builder out = Expression.FieldReference.newBuilder().setDirectReference(seg); if (expr.inputExpression().isPresent()) { @@ -550,7 +580,8 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) { @Override public Expression visit( - io.substrait.expression.Expression.SetPredicate expr, EmptyVisitationContext context) + final io.substrait.expression.Expression.SetPredicate expr, + final EmptyVisitationContext context) throws RuntimeException { return Expression.newBuilder() .setSubquery( @@ -566,7 +597,8 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.ScalarSubquery expr, EmptyVisitationContext context) + final io.substrait.expression.Expression.ScalarSubquery expr, + final EmptyVisitationContext context) throws RuntimeException { return Expression.newBuilder() .setSubquery( @@ -579,7 +611,8 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.InPredicate expr, EmptyVisitationContext context) + final io.substrait.expression.Expression.InPredicate expr, + final EmptyVisitationContext context) throws RuntimeException { return Expression.newBuilder() .setSubquery( @@ -595,20 +628,20 @@ public Expression visit( @Override public Expression visit( - io.substrait.expression.Expression.WindowFunctionInvocation expr, - EmptyVisitationContext context) + final io.substrait.expression.Expression.WindowFunctionInvocation expr, + final EmptyVisitationContext context) throws RuntimeException { - FunctionArg.FuncArgVisitor + final FunctionArg.FuncArgVisitor argVisitor = FunctionArg.toProto(typeProtoConverter, this); - List args = + final List args = expr.arguments().stream() .map(a -> a.accept(expr.declaration(), 0, argVisitor, context)) .collect(java.util.stream.Collectors.toList()); - Type outputType = toProto(expr.getType()); + final Type outputType = toProto(expr.getType()); - List partitionExprs = toProto(expr.partitionBy()); + final List partitionExprs = toProto(expr.partitionBy()); - List sortFields = + final List sortFields = expr.sort().stream() .map( s -> @@ -618,8 +651,8 @@ public Expression visit( .build()) .collect(java.util.stream.Collectors.toList()); - Expression.WindowFunction.Bound lowerBound = BoundConverter.convert(expr.lowerBound()); - Expression.WindowFunction.Bound upperBound = BoundConverter.convert(expr.upperBound()); + final Expression.WindowFunction.Bound lowerBound = BoundConverter.convert(expr.lowerBound()); + final Expression.WindowFunction.Bound upperBound = BoundConverter.convert(expr.upperBound()); return Expression.newBuilder() .setWindowFunction( @@ -645,14 +678,14 @@ public static class BoundConverter implements WindowBound.WindowBoundVisitor { private static final BoundConverter TO_BOUND_VISITOR = new BoundConverter(); - public static Expression.WindowFunction.Bound convert(WindowBound bound) { + public static Expression.WindowFunction.Bound convert(final WindowBound bound) { return bound.accept(TO_BOUND_VISITOR); } private BoundConverter() {} @Override - public Expression.WindowFunction.Bound visit(WindowBound.Preceding preceding) { + public Expression.WindowFunction.Bound visit(final WindowBound.Preceding preceding) { return Expression.WindowFunction.Bound.newBuilder() .setPreceding( Expression.WindowFunction.Bound.Preceding.newBuilder().setOffset(preceding.offset())) @@ -660,7 +693,7 @@ public Expression.WindowFunction.Bound visit(WindowBound.Preceding preceding) { } @Override - public Expression.WindowFunction.Bound visit(WindowBound.Following following) { + public Expression.WindowFunction.Bound visit(final WindowBound.Following following) { return Expression.WindowFunction.Bound.newBuilder() .setFollowing( Expression.WindowFunction.Bound.Following.newBuilder().setOffset(following.offset())) @@ -668,14 +701,14 @@ public Expression.WindowFunction.Bound visit(WindowBound.Following following) { } @Override - public Expression.WindowFunction.Bound visit(WindowBound.CurrentRow currentRow) { + public Expression.WindowFunction.Bound visit(final WindowBound.CurrentRow currentRow) { return Expression.WindowFunction.Bound.newBuilder() .setCurrentRow(Expression.WindowFunction.Bound.CurrentRow.getDefaultInstance()) .build(); } @Override - public Expression.WindowFunction.Bound visit(WindowBound.Unbounded unbounded) { + public Expression.WindowFunction.Bound visit(final WindowBound.Unbounded unbounded) { return Expression.WindowFunction.Bound.newBuilder() .setUnbounded(Expression.WindowFunction.Bound.Unbounded.getDefaultInstance()) .build(); diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 8f95cdf07..a7004e999 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -39,10 +39,10 @@ public class ProtoExpressionConverter { private final ProtoRelConverter protoRelConverter; public ProtoExpressionConverter( - ExtensionLookup lookup, - SimpleExtension.ExtensionCollection extensions, - Type.Struct rootType, - ProtoRelConverter relConverter) { + final ExtensionLookup lookup, + final SimpleExtension.ExtensionCollection extensions, + final Type.Struct rootType, + final ProtoRelConverter relConverter) { this.lookup = lookup; this.extensions = extensions; this.rootType = Objects.requireNonNull(rootType, "rootType"); @@ -50,8 +50,8 @@ public ProtoExpressionConverter( this.protoRelConverter = relConverter; } - public FieldReference from(io.substrait.proto.Expression.FieldReference reference) { - io.substrait.proto.Expression.FieldReference.ReferenceTypeCase refTypeCase = + public FieldReference from(final io.substrait.proto.Expression.FieldReference reference) { + final io.substrait.proto.Expression.FieldReference.ReferenceTypeCase refTypeCase = reference.getReferenceTypeCase(); if (refTypeCase == ReferenceTypeCase.MASKED_REFERENCE) { @@ -83,24 +83,24 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc private List getDirectReferenceSegments( io.substrait.proto.Expression.ReferenceSegment segment) { - List results = new ArrayList<>(); + final List results = new ArrayList<>(); while (segment != io.substrait.proto.Expression.ReferenceSegment.getDefaultInstance()) { final ReferenceSegment mappedSegment; switch (segment.getReferenceTypeCase()) { case MAP_KEY: - io.substrait.proto.Expression.ReferenceSegment.MapKey mapKey = segment.getMapKey(); + final io.substrait.proto.Expression.ReferenceSegment.MapKey mapKey = segment.getMapKey(); segment = mapKey.getChild(); mappedSegment = FieldReference.MapKey.of(from(mapKey.getMapKey())); break; case STRUCT_FIELD: - io.substrait.proto.Expression.ReferenceSegment.StructField structField = + final io.substrait.proto.Expression.ReferenceSegment.StructField structField = segment.getStructField(); segment = structField.getChild(); mappedSegment = FieldReference.StructField.of(structField.getField()); break; case LIST_ELEMENT: - io.substrait.proto.Expression.ReferenceSegment.ListElement listElement = + final io.substrait.proto.Expression.ReferenceSegment.ListElement listElement = segment.getListElement(); segment = listElement.getChild(); mappedSegment = FieldReference.ListElement.of(listElement.getOffset()); @@ -118,7 +118,7 @@ private List getDirectReferenceSegments( return results; } - public Expression from(io.substrait.proto.Expression expr) { + public Expression from(final io.substrait.proto.Expression expr) { switch (expr.getRexTypeCase()) { case LITERAL: return from(expr.getLiteral()); @@ -126,16 +126,17 @@ public Expression from(io.substrait.proto.Expression expr) { return from(expr.getSelection()); case SCALAR_FUNCTION: { - io.substrait.proto.Expression.ScalarFunction scalarFunction = expr.getScalarFunction(); - int functionReference = scalarFunction.getFunctionReference(); - SimpleExtension.ScalarFunctionVariant declaration = + final io.substrait.proto.Expression.ScalarFunction scalarFunction = + expr.getScalarFunction(); + final int functionReference = scalarFunction.getFunctionReference(); + final SimpleExtension.ScalarFunctionVariant declaration = lookup.getScalarFunction(functionReference, extensions); - FunctionArg.ProtoFrom pF = new FunctionArg.ProtoFrom(this, protoTypeConverter); - List args = + final FunctionArg.ProtoFrom pF = new FunctionArg.ProtoFrom(this, protoTypeConverter); + final List args = IntStream.range(0, scalarFunction.getArgumentsCount()) .mapToObj(i -> pF.convert(declaration, i, scalarFunction.getArguments(i))) .collect(Collectors.toList()); - List options = + final List options = scalarFunction.getOptionsList().stream() .map(ProtoExpressionConverter::fromFunctionOption) .collect(Collectors.toList()); @@ -150,8 +151,8 @@ public Expression from(io.substrait.proto.Expression expr) { return fromWindowFunction(expr.getWindowFunction()); case IF_THEN: { - io.substrait.proto.Expression.IfThen ifThen = expr.getIfThen(); - List clauses = + final io.substrait.proto.Expression.IfThen ifThen = expr.getIfThen(); + final List clauses = ifThen.getIfsList().stream() .map(t -> ExpressionCreator.ifThenClause(from(t.getIf()), from(t.getThen()))) .collect(Collectors.toList()); @@ -159,8 +160,9 @@ public Expression from(io.substrait.proto.Expression expr) { } case SWITCH_EXPRESSION: { - io.substrait.proto.Expression.SwitchExpression switchExpr = expr.getSwitchExpression(); - List clauses = + final io.substrait.proto.Expression.SwitchExpression switchExpr = + expr.getSwitchExpression(); + final List clauses = switchExpr.getIfsList().stream() .map(t -> ExpressionCreator.switchClause(from(t.getIf()), from(t.getThen()))) .collect(Collectors.toList()); @@ -169,8 +171,8 @@ public Expression from(io.substrait.proto.Expression expr) { } case SINGULAR_OR_LIST: { - io.substrait.proto.Expression.SingularOrList orList = expr.getSingularOrList(); - List values = + final io.substrait.proto.Expression.SingularOrList orList = expr.getSingularOrList(); + final List values = orList.getOptionsList().stream().map(this::from).collect(Collectors.toList()); return Expression.SingleOrList.builder() .condition(from(orList.getValue())) @@ -179,8 +181,8 @@ public Expression from(io.substrait.proto.Expression expr) { } case MULTI_OR_LIST: { - io.substrait.proto.Expression.MultiOrList multiOrList = expr.getMultiOrList(); - List values = + final io.substrait.proto.Expression.MultiOrList multiOrList = expr.getMultiOrList(); + final List values = multiOrList.getOptionsList().stream() .map( t -> @@ -207,7 +209,7 @@ public Expression from(io.substrait.proto.Expression expr) { switch (expr.getSubquery().getSubqueryTypeCase()) { case SET_PREDICATE: { - io.substrait.relation.Rel rel = + final io.substrait.relation.Rel rel = protoRelConverter.from(expr.getSubquery().getSetPredicate().getTuples()); return Expression.SetPredicate.builder() .tuples(rel) @@ -218,7 +220,7 @@ public Expression from(io.substrait.proto.Expression expr) { } case SCALAR: { - io.substrait.relation.Rel rel = + final io.substrait.relation.Rel rel = protoRelConverter.from(expr.getSubquery().getScalar().getInput()); return Expression.ScalarSubquery.builder() .input(rel) @@ -228,7 +230,8 @@ public Expression from(io.substrait.proto.Expression expr) { new TypeVisitor.TypeThrowsVisitor( "Expected struct field") { @Override - public Type visit(Type.Struct type) throws RuntimeException { + public Type visit(final Type.Struct type) + throws RuntimeException { if (type.fields().size() != 1) { throw new UnsupportedOperationException( "Scalar subquery must have exactly one field"); @@ -241,9 +244,9 @@ public Type visit(Type.Struct type) throws RuntimeException { } case IN_PREDICATE: { - io.substrait.relation.Rel rel = + final io.substrait.relation.Rel rel = protoRelConverter.from(expr.getSubquery().getInPredicate().getHaystack()); - List needles = + final List needles = expr.getSubquery().getInPredicate().getNeedlesList().stream() .map(e -> this.from(e)) .collect(Collectors.toList()); @@ -267,31 +270,31 @@ public Type visit(Type.Struct type) throws RuntimeException { } public Expression.WindowFunctionInvocation fromWindowFunction( - io.substrait.proto.Expression.WindowFunction windowFunction) { - int functionReference = windowFunction.getFunctionReference(); - SimpleExtension.WindowFunctionVariant declaration = + final io.substrait.proto.Expression.WindowFunction windowFunction) { + final int functionReference = windowFunction.getFunctionReference(); + final SimpleExtension.WindowFunctionVariant declaration = lookup.getWindowFunction(functionReference, extensions); - FunctionArg.ProtoFrom argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter); + final FunctionArg.ProtoFrom argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter); - List args = + final List args = fromFunctionArgumentList( windowFunction.getArgumentsCount(), argVisitor, declaration, windowFunction::getArguments); - List partitionExprs = + final List partitionExprs = windowFunction.getPartitionsList().stream().map(this::from).collect(Collectors.toList()); - List sortFields = + final List sortFields = windowFunction.getSortsList().stream() .map(this::fromSortField) .collect(Collectors.toList()); - List options = + final List options = windowFunction.getOptionsList().stream() .map(ProtoExpressionConverter::fromFunctionOption) .collect(Collectors.toList()); - WindowBound lowerBound = toWindowBound(windowFunction.getLowerBound()); - WindowBound upperBound = toWindowBound(windowFunction.getUpperBound()); + final WindowBound lowerBound = toWindowBound(windowFunction.getLowerBound()); + final WindowBound upperBound = toWindowBound(windowFunction.getUpperBound()); return Expression.WindowFunctionInvocation.builder() .arguments(args) @@ -309,25 +312,25 @@ public Expression.WindowFunctionInvocation fromWindowFunction( } public ConsistentPartitionWindow.WindowRelFunctionInvocation fromWindowRelFunction( - ConsistentPartitionWindowRel.WindowRelFunction windowRelFunction) { - int functionReference = windowRelFunction.getFunctionReference(); - SimpleExtension.WindowFunctionVariant declaration = + final ConsistentPartitionWindowRel.WindowRelFunction windowRelFunction) { + final int functionReference = windowRelFunction.getFunctionReference(); + final SimpleExtension.WindowFunctionVariant declaration = lookup.getWindowFunction(functionReference, extensions); - FunctionArg.ProtoFrom argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter); + final FunctionArg.ProtoFrom argVisitor = new FunctionArg.ProtoFrom(this, protoTypeConverter); - List args = + final List args = fromFunctionArgumentList( windowRelFunction.getArgumentsCount(), argVisitor, declaration, windowRelFunction::getArguments); - List options = + final List options = windowRelFunction.getOptionsList().stream() .map(ProtoExpressionConverter::fromFunctionOption) .collect(Collectors.toList()); - WindowBound lowerBound = toWindowBound(windowRelFunction.getLowerBound()); - WindowBound upperBound = toWindowBound(windowRelFunction.getUpperBound()); + final WindowBound lowerBound = toWindowBound(windowRelFunction.getLowerBound()); + final WindowBound upperBound = toWindowBound(windowRelFunction.getUpperBound()); return ConsistentPartitionWindow.WindowRelFunctionInvocation.builder() .arguments(args) @@ -342,7 +345,8 @@ public ConsistentPartitionWindow.WindowRelFunctionInvocation fromWindowRelFuncti .build(); } - private WindowBound toWindowBound(io.substrait.proto.Expression.WindowFunction.Bound bound) { + private WindowBound toWindowBound( + final io.substrait.proto.Expression.WindowFunction.Bound bound) { switch (bound.getKindCase()) { case PRECEDING: return WindowBound.Preceding.of(bound.getPreceding().getOffset()); @@ -361,7 +365,7 @@ private WindowBound toWindowBound(io.substrait.proto.Expression.WindowFunction.B } } - public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { + public Expression.Literal from(final io.substrait.proto.Expression.Literal literal) { switch (literal.getLiteralTypeCase()) { case BOOLEAN: return ExpressionCreator.bool(literal.getNullable(), literal.getBoolean()); @@ -408,11 +412,11 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { { // Handle deprecated version that doesn't provide precision and that uses microseconds // instead of subseconds, for backwards compatibility - int precision = + final int precision = literal.getIntervalDayToSecond().hasPrecision() ? literal.getIntervalDayToSecond().getPrecision() : 6; // microseconds - long subseconds = + final long subseconds = literal.getIntervalDayToSecond().hasPrecision() ? literal.getIntervalDayToSecond().getSubseconds() : literal.getIntervalDayToSecond().getMicroseconds(); @@ -468,7 +472,7 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { { // literal.getNullable() is intentionally ignored in favor of the nullability // specified in the literal.getEmptyMap() type. - Type.Map mapType = protoTypeConverter.fromMap(literal.getEmptyMap()); + final Type.Map mapType = protoTypeConverter.fromMap(literal.getEmptyMap()); return ExpressionCreator.emptyMap(mapType.nullable(), mapType.key(), mapType.value()); } case UUID: @@ -485,14 +489,14 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { { // literal.getNullable() is intentionally ignored in favor of the nullability // specified in the literal.getEmptyList() type. - Type.ListType listType = protoTypeConverter.fromList(literal.getEmptyList()); + final Type.ListType listType = protoTypeConverter.fromList(literal.getEmptyList()); return ExpressionCreator.emptyList(listType.nullable(), listType.elementType()); } case USER_DEFINED: { - io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral = + final io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral = literal.getUserDefined(); - SimpleExtension.Type type = + final SimpleExtension.Type type = lookup.getType(userDefinedLiteral.getTypeReference(), extensions); return ExpressionCreator.userDefinedLiteral( literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue()); @@ -503,23 +507,23 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { } private static List fromFunctionArgumentList( - int argumentsCount, - FunctionArg.ProtoFrom argVisitor, - SimpleExtension.Function declaration, - Function argFunction) { + final int argumentsCount, + final FunctionArg.ProtoFrom argVisitor, + final SimpleExtension.Function declaration, + final Function argFunction) { return IntStream.range(0, argumentsCount) .mapToObj(i -> argVisitor.convert(declaration, i, argFunction.apply(i))) .collect(Collectors.toList()); } - public Expression.SortField fromSortField(SortField s) { + public Expression.SortField fromSortField(final SortField s) { return Expression.SortField.builder() .direction(Expression.SortDirection.fromProto(s.getDirection())) .expr(from(s.getExpr())) .build(); } - public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) { + public static FunctionOption fromFunctionOption(final io.substrait.proto.FunctionOption o) { return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build(); } } diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index e0cb9d1c5..df70744a2 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -15,24 +15,24 @@ public class ExtendedExpressionProtoConverter { public ExtendedExpression toProto( - io.substrait.extendedexpression.ExtendedExpression extendedExpression) { + final io.substrait.extendedexpression.ExtendedExpression extendedExpression) { - ExtendedExpression.Builder builder = ExtendedExpression.newBuilder(); - ExtensionCollector functionCollector = new ExtensionCollector(); + final ExtendedExpression.Builder builder = ExtendedExpression.newBuilder(); + final ExtensionCollector functionCollector = new ExtensionCollector(); final ExpressionProtoConverter expressionProtoConverter = new ExpressionProtoConverter(functionCollector, null); - for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReferenceBase + for (final io.substrait.extendedexpression.ExtendedExpression.ExpressionReferenceBase expressionReference : extendedExpression.getReferredExpressions()) { if (expressionReference instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionReference) { - io.substrait.extendedexpression.ExtendedExpression.ExpressionReference et = + final io.substrait.extendedexpression.ExtendedExpression.ExpressionReference et = (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference) expressionReference; - io.substrait.proto.Expression expressionProto = + final io.substrait.proto.Expression expressionProto = et.getExpression().accept(expressionProtoConverter, EmptyVisitationContext.INSTANCE); - ExpressionReference.Builder expressionReferenceBuilder = + final ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setExpression(expressionProto) .addAllOutputNames(expressionReference.getOutputNames()); @@ -40,10 +40,10 @@ public ExtendedExpression toProto( } else if (expressionReference instanceof io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionReference) { - io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionReference aft = + final io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionReference aft = (io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionReference) expressionReference; - ExpressionReference.Builder expressionReferenceBuilder = + final ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setMeasure( new AggregateFunctionProtoConverter(functionCollector) diff --git a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index 175b2d705..8294d60ab 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -27,37 +27,38 @@ public ProtoExtendedExpressionConverter() { this(DefaultExtensionCatalog.DEFAULT_COLLECTION); } - public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection extensionCollection) { + public ProtoExtendedExpressionConverter( + final SimpleExtension.ExtensionCollection extensionCollection) { if (extensionCollection == null) { throw new IllegalArgumentException("ExtensionCollection is required"); } this.extensionCollection = extensionCollection; } - public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExpression) { + public ExtendedExpression from(final io.substrait.proto.ExtendedExpression extendedExpression) { // fill in simple extension information through a discovery in the current proto-extended // expression - ExtensionLookup functionLookup = + final ExtensionLookup functionLookup = ImmutableExtensionLookup.builder(extensionCollection).from(extendedExpression).build(); - NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); + final NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); - io.substrait.type.NamedStruct namedStruct = + final io.substrait.type.NamedStruct namedStruct = io.substrait.type.NamedStruct.fromProto(baseSchemaProto, protoTypeConverter); - ProtoExpressionConverter protoExpressionConverter = + final ProtoExpressionConverter protoExpressionConverter = new ProtoExpressionConverter( functionLookup, this.extensionCollection, namedStruct.struct(), null); - List expressionReferences = new ArrayList<>(); + final List expressionReferences = new ArrayList<>(); - for (ExpressionReference expressionReference : extendedExpression.getReferredExprList()) { + for (final ExpressionReference expressionReference : extendedExpression.getReferredExprList()) { switch (expressionReference.getExprTypeCase()) { case EXPRESSION: - Expression expressionPojo = + final Expression expressionPojo = protoExpressionConverter.from(expressionReference.getExpression()); - ImmutableExpressionReference buildExpression = + final ImmutableExpressionReference buildExpression = ImmutableExpressionReference.builder() .expression(expressionPojo) .addAllOutputNames(expressionReference.getOutputNamesList()) @@ -65,14 +66,14 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp expressionReferences.add(buildExpression); break; case MEASURE: - io.substrait.relation.Aggregate.Measure measure = + final io.substrait.relation.Aggregate.Measure measure = io.substrait.relation.Aggregate.Measure.builder() .function( new ProtoAggregateFunctionConverter( functionLookup, extensionCollection, protoExpressionConverter) .from(expressionReference.getMeasure())) .build(); - ImmutableAggregateFunctionReference buildMeasure = + final ImmutableAggregateFunctionReference buildMeasure = ImmutableAggregateFunctionReference.builder() .measure(measure) .addAllOutputNames(expressionReference.getOutputNamesList()) @@ -88,7 +89,7 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp } } - ImmutableExtendedExpression.Builder builder = + final ImmutableExtendedExpression.Builder builder = ImmutableExtendedExpression.builder() .referredExpressions(expressionReferences) .advancedExtension( diff --git a/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java b/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java index 16e41f03f..5d29e2105 100644 --- a/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/AbstractExtensionLookup.java @@ -7,16 +7,16 @@ public abstract class AbstractExtensionLookup implements ExtensionLookup { protected final Map typeAnchorMap; public AbstractExtensionLookup( - Map functionAnchorMap, - Map typeAnchorMap) { + final Map functionAnchorMap, + final Map typeAnchorMap) { this.functionAnchorMap = functionAnchorMap; this.typeAnchorMap = typeAnchorMap; } @Override public SimpleExtension.ScalarFunctionVariant getScalarFunction( - int reference, SimpleExtension.ExtensionCollection extensions) { - SimpleExtension.FunctionAnchor anchor = functionAnchorMap.get(reference); + final int reference, final SimpleExtension.ExtensionCollection extensions) { + final SimpleExtension.FunctionAnchor anchor = functionAnchorMap.get(reference); if (anchor == null) { throw new IllegalArgumentException( "Unknown function id. Make sure that the function id provided was shared in the extensions section of the plan."); @@ -27,8 +27,8 @@ public SimpleExtension.ScalarFunctionVariant getScalarFunction( @Override public SimpleExtension.WindowFunctionVariant getWindowFunction( - int reference, SimpleExtension.ExtensionCollection extensions) { - SimpleExtension.FunctionAnchor anchor = functionAnchorMap.get(reference); + final int reference, final SimpleExtension.ExtensionCollection extensions) { + final SimpleExtension.FunctionAnchor anchor = functionAnchorMap.get(reference); if (anchor == null) { throw new IllegalArgumentException( "Unknown function id. Make sure that the function id provided was shared in the extensions section of the plan."); @@ -39,8 +39,8 @@ public SimpleExtension.WindowFunctionVariant getWindowFunction( @Override public SimpleExtension.AggregateFunctionVariant getAggregateFunction( - int reference, SimpleExtension.ExtensionCollection extensions) { - SimpleExtension.FunctionAnchor anchor = functionAnchorMap.get(reference); + final int reference, final SimpleExtension.ExtensionCollection extensions) { + final SimpleExtension.FunctionAnchor anchor = functionAnchorMap.get(reference); if (anchor == null) { throw new IllegalArgumentException( "Unknown function id. Make sure that the function id provided was shared in the extensions section of the plan."); @@ -51,8 +51,8 @@ public SimpleExtension.AggregateFunctionVariant getAggregateFunction( @Override public SimpleExtension.Type getType( - int reference, SimpleExtension.ExtensionCollection extensions) { - SimpleExtension.TypeAnchor anchor = typeAnchorMap.get(reference); + final int reference, final SimpleExtension.ExtensionCollection extensions) { + final SimpleExtension.TypeAnchor anchor = typeAnchorMap.get(reference); if (anchor == null) { throw new IllegalArgumentException( "Unknown type id. Make sure that the type id provided was shared in the extensions section of the plan."); diff --git a/core/src/main/java/io/substrait/extension/BidiMap.java b/core/src/main/java/io/substrait/extension/BidiMap.java index f0eeb30a7..3a20f0d64 100644 --- a/core/src/main/java/io/substrait/extension/BidiMap.java +++ b/core/src/main/java/io/substrait/extension/BidiMap.java @@ -9,10 +9,10 @@ public class BidiMap { private final Map forwardMap; private final Map reverseMap; - BidiMap(Map forwardMap) { + BidiMap(final Map forwardMap) { this.forwardMap = forwardMap; this.reverseMap = new HashMap<>(); - for (Map.Entry entry : forwardMap.entrySet()) { + for (final Map.Entry entry : forwardMap.entrySet()) { reverseMap.put(entry.getValue(), entry.getKey()); } } @@ -22,11 +22,11 @@ public class BidiMap { this.reverseMap = new HashMap<>(); } - T2 get(T1 t1) { + T2 get(final T1 t1) { return forwardMap.get(t1); } - T1 reverseGet(T2 t2) { + T1 reverseGet(final T2 t2) { return reverseMap.get(t2); } @@ -34,9 +34,9 @@ T1 reverseGet(T2 t2) { * Associates the specified values in both directions. Throws if either value is already mapped to * a different value. */ - void put(T1 t1, T2 t2) { - T2 existingForward = forwardMap.get(t1); - T1 existingReverse = reverseMap.get(t2); + void put(final T1 t1, final T2 t2) { + final T2 existingForward = forwardMap.get(t1); + final T1 existingReverse = reverseMap.get(t2); if (existingForward != null && !existingForward.equals(t2)) { throw new IllegalArgumentException("Key already exists in map with different value"); @@ -49,8 +49,8 @@ void put(T1 t1, T2 t2) { reverseMap.put(t2, t1); } - void merge(BidiMap other) { - for (Map.Entry entry : other.forwardEntrySet()) { + void merge(final BidiMap other) { + for (final Map.Entry entry : other.forwardEntrySet()) { put(entry.getKey(), entry.getValue()); } } diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 89aad954e..f16bf7865 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -27,7 +27,7 @@ public class DefaultExtensionCatalog { loadDefaultCollection(); private static SimpleExtension.ExtensionCollection loadDefaultCollection() { - List defaultFiles = + final List defaultFiles = Arrays.asList( "boolean", "aggregate_generic", diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index 7ad07a6b1..f8010cb46 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -25,7 +25,7 @@ public class ExtensionCollector extends AbstractExtensionLookup { // start at 0 to make sure functionAnchors start with 1 according to spec private int counter = 0; - private String getUriFromUrn(String urn) { + private String getUriFromUrn(final String urn) { return extensionCollection.getUriFromUrn(urn); } @@ -33,7 +33,7 @@ public ExtensionCollector() { this(DefaultExtensionCatalog.DEFAULT_COLLECTION); } - public ExtensionCollector(SimpleExtension.ExtensionCollection extensionCollection) { + public ExtensionCollector(final SimpleExtension.ExtensionCollection extensionCollection) { super(new HashMap<>(), new HashMap<>()); if (extensionCollection == null) { throw new IllegalArgumentException("ExtensionCollection is required"); @@ -43,8 +43,8 @@ public ExtensionCollector(SimpleExtension.ExtensionCollection extensionCollectio this.extensionCollection = extensionCollection; } - public int getFunctionReference(SimpleExtension.Function declaration) { - Integer i = funcMap.reverseGet(declaration.getAnchor()); + public int getFunctionReference(final SimpleExtension.Function declaration) { + final Integer i = funcMap.reverseGet(declaration.getAnchor()); if (i != null) { return i; } @@ -53,8 +53,8 @@ public int getFunctionReference(SimpleExtension.Function declaration) { return counter; } - public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) { - Integer i = typeMap.reverseGet(typeAnchor); + public int getTypeReference(final SimpleExtension.TypeAnchor typeAnchor) { + final Integer i = typeMap.reverseGet(typeAnchor); if (i != null) { return i; } @@ -63,16 +63,16 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) { return counter; } - public void addExtensionsToPlan(Plan.Builder builder) { - SimpleExtensions simpleExtensions = getExtensions(); + public void addExtensionsToPlan(final Plan.Builder builder) { + final SimpleExtensions simpleExtensions = getExtensions(); builder.addAllExtensionUrns(simpleExtensions.urns.values()); builder.addAllExtensionUris(simpleExtensions.uris.values()); builder.addAllExtensions(simpleExtensions.extensionList); } - public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { - SimpleExtensions simpleExtensions = getExtensions(); + public void addExtensionsToExtendedExpression(final ExtendedExpression.Builder builder) { + final SimpleExtensions simpleExtensions = getExtensions(); builder.addAllExtensionUrns(simpleExtensions.urns.values()); builder.addAllExtensionUris(simpleExtensions.uris.values()); @@ -80,18 +80,18 @@ public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder } private SimpleExtensions getExtensions() { - AtomicInteger urnPos = new AtomicInteger(1); - AtomicInteger uriPos = new AtomicInteger(1); - HashMap urns = new HashMap<>(); - HashMap uris = new HashMap<>(); + final AtomicInteger urnPos = new AtomicInteger(1); + final AtomicInteger uriPos = new AtomicInteger(1); + final HashMap urns = new HashMap<>(); + final HashMap uris = new HashMap<>(); - ArrayList extensionList = new ArrayList<>(); - for (Map.Entry e : funcMap.forwardEntrySet()) { - String urn = e.getValue().urn(); - String uri = getUriFromUrn(urn); + final ArrayList extensionList = new ArrayList<>(); + for (final Map.Entry e : funcMap.forwardEntrySet()) { + final String urn = e.getValue().urn(); + final String uri = getUriFromUrn(urn); // Create URN entry - SimpleExtensionURN urnObj = + final SimpleExtensionURN urnObj = urns.computeIfAbsent( urn, k -> @@ -114,7 +114,7 @@ private SimpleExtensions getExtensions() { } // Create function declaration with both URN and URI references - SimpleExtensionDeclaration.ExtensionFunction.Builder funcBuilder = + final SimpleExtensionDeclaration.ExtensionFunction.Builder funcBuilder = SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(e.getKey()) .setName(e.getValue().key()) @@ -124,17 +124,17 @@ private SimpleExtensions getExtensions() { funcBuilder.setExtensionUriReference(uriObj.getExtensionUriAnchor()); } - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionFunction(funcBuilder).build(); extensionList.add(decl); } - for (Map.Entry e : typeMap.forwardEntrySet()) { - String urn = e.getValue().urn(); - String uri = getUriFromUrn(urn); + for (final Map.Entry e : typeMap.forwardEntrySet()) { + final String urn = e.getValue().urn(); + final String uri = getUriFromUrn(urn); // Create URN entry - SimpleExtensionURN urnObj = + final SimpleExtensionURN urnObj = urns.computeIfAbsent( urn, k -> @@ -157,7 +157,7 @@ private SimpleExtensions getExtensions() { } // Create type declaration with both URN and URI references - SimpleExtensionDeclaration.ExtensionType.Builder typeBuilder = + final SimpleExtensionDeclaration.ExtensionType.Builder typeBuilder = SimpleExtensionDeclaration.ExtensionType.newBuilder() .setTypeAnchor(e.getKey()) .setName(e.getValue().key()) @@ -167,7 +167,7 @@ private SimpleExtensions getExtensions() { typeBuilder.setExtensionUriReference(uriObj.getExtensionUriAnchor()); } - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionType(typeBuilder).build(); extensionList.add(decl); } @@ -180,9 +180,9 @@ private static final class SimpleExtensions { final ArrayList extensionList; SimpleExtensions( - HashMap urns, - HashMap uris, - ArrayList extensionList) { + final HashMap urns, + final HashMap uris, + final ArrayList extensionList) { this.urns = urns; this.uris = uris; this.extensionList = extensionList; diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 9b546d9dc..72b3adaf4 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -16,8 +16,8 @@ public class ImmutableExtensionLookup extends AbstractExtensionLookup { private ImmutableExtensionLookup( - Map functionMap, - Map typeMap) { + final Map functionMap, + final Map typeMap) { super(functionMap, typeMap); } @@ -25,7 +25,7 @@ public static Builder builder() { return builder(DefaultExtensionCatalog.DEFAULT_COLLECTION); } - public static Builder builder(SimpleExtension.ExtensionCollection extensionCollection) { + public static Builder builder(final SimpleExtension.ExtensionCollection extensionCollection) { return new Builder(extensionCollection); } @@ -34,7 +34,7 @@ public static class Builder { private final Map typeMap = new HashMap<>(); private final SimpleExtension.ExtensionCollection extensionCollection; - public Builder(SimpleExtension.ExtensionCollection extensionCollection) { + public Builder(final SimpleExtension.ExtensionCollection extensionCollection) { if (extensionCollection == null) { throw new IllegalArgumentException("ExtensionCollection is required"); } @@ -47,18 +47,18 @@ public Builder(SimpleExtension.ExtensionCollection extensionCollection) { * @param uri The URI to resolve * @return The corresponding URN, or null if no mapping exists */ - private String resolveUrnFromUri(String uri) { + private String resolveUrnFromUri(final String uri) { return extensionCollection.getUrnFromUri(uri); } private SimpleExtension.FunctionAnchor resolveFunctionAnchor( - SimpleExtensionDeclaration.ExtensionFunction func, - Map urnMap, - Map uriMap) { + final SimpleExtensionDeclaration.ExtensionFunction func, + final Map urnMap, + final Map uriMap) { // 1. Try non-zero URN reference if (func.getExtensionUrnReference() != 0) { - String urnFromUrnRef = urnMap.get(func.getExtensionUrnReference()); + final String urnFromUrnRef = urnMap.get(func.getExtensionUrnReference()); if (urnFromUrnRef == null) { throw new IllegalStateException( String.format( @@ -70,14 +70,14 @@ private SimpleExtension.FunctionAnchor resolveFunctionAnchor( // 2. Try non-zero URI reference if (func.getExtensionUriReference() != 0) { - String uriFromUriRef = uriMap.get(func.getExtensionUriReference()); + final String uriFromUriRef = uriMap.get(func.getExtensionUriReference()); if (uriFromUriRef == null) { throw new IllegalStateException( String.format( "Function '%s' references URI anchor %d, but no URI is registered at that anchor", func.getName(), func.getExtensionUriReference())); } - String urnFromUriRef = resolveUrnFromUri(uriFromUriRef); + final String urnFromUriRef = resolveUrnFromUri(uriFromUriRef); if (urnFromUriRef == null) { throw new IllegalStateException( String.format( @@ -94,12 +94,12 @@ private SimpleExtension.FunctionAnchor resolveFunctionAnchor( * We perform some additional checks to below to handle this. */ - String urn = urnMap.get(func.getExtensionUrnReference()); - String uri = uriMap.get(func.getExtensionUriReference()); + final String urn = urnMap.get(func.getExtensionUrnReference()); + final String uri = uriMap.get(func.getExtensionUriReference()); // 3. Try both 0 URI and 0 URN if both resolve if (uri != null && urn != null) { - String resolvedUrn = resolveUrnFromUri(uri); + final String resolvedUrn = resolveUrnFromUri(uri); if (urn.equals(resolvedUrn)) { return SimpleExtension.FunctionAnchor.of(urn, func.getName()); } @@ -125,13 +125,13 @@ private SimpleExtension.FunctionAnchor resolveFunctionAnchor( } private SimpleExtension.TypeAnchor resolveTypeAnchor( - SimpleExtensionDeclaration.ExtensionType type, - Map urnMap, - Map uriMap) { + final SimpleExtensionDeclaration.ExtensionType type, + final Map urnMap, + final Map uriMap) { // 1. Try non-zero URN reference if (type.getExtensionUrnReference() != 0) { - String urnFromUrnRef = urnMap.get(type.getExtensionUrnReference()); + final String urnFromUrnRef = urnMap.get(type.getExtensionUrnReference()); if (urnFromUrnRef == null) { throw new IllegalStateException( String.format( @@ -143,14 +143,14 @@ private SimpleExtension.TypeAnchor resolveTypeAnchor( // 2. Try non-zero URI reference if (type.getExtensionUriReference() != 0) { - String uriFromUriRef = uriMap.get(type.getExtensionUriReference()); + final String uriFromUriRef = uriMap.get(type.getExtensionUriReference()); if (uriFromUriRef == null) { throw new IllegalStateException( String.format( "Type '%s' references URI anchor %d, but no URI is registered at that anchor", type.getName(), type.getExtensionUriReference())); } - String urnFromUriRef = resolveUrnFromUri(uriFromUriRef); + final String urnFromUriRef = resolveUrnFromUri(uriFromUriRef); if (urnFromUriRef == null) { throw new IllegalStateException( String.format( @@ -167,12 +167,12 @@ private SimpleExtension.TypeAnchor resolveTypeAnchor( * We perform some additional checks to below to handle this. */ - String urn = urnMap.get(type.getExtensionUrnReference()); - String uri = uriMap.get(type.getExtensionUriReference()); + final String urn = urnMap.get(type.getExtensionUrnReference()); + final String uri = uriMap.get(type.getExtensionUriReference()); // 3. Try both 0 URI and 0 URN if both resolve if (uri != null && urn != null) { - String resolvedUrn = resolveUrnFromUri(uri); + final String resolvedUrn = resolveUrnFromUri(uri); if (urn.equals(resolvedUrn)) { return SimpleExtension.TypeAnchor.of(urn, type.getName()); } @@ -197,12 +197,12 @@ private SimpleExtension.TypeAnchor resolveTypeAnchor( uri, urn)); } - public Builder from(Plan plan) { + public Builder from(final Plan plan) { return from( plan.getExtensionUrnsList(), plan.getExtensionUrisList(), plan.getExtensionsList()); } - public Builder from(ExtendedExpression extendedExpression) { + public Builder from(final ExtendedExpression extendedExpression) { return from( extendedExpression.getExtensionUrnsList(), extendedExpression.getExtensionUrisList(), @@ -210,41 +210,41 @@ public Builder from(ExtendedExpression extendedExpression) { } private Builder from( - List simpleExtensionURNs, - List simpleExtensionURIs, - List simpleExtensionDeclarations) { - Map urnMap = new HashMap<>(); - Map uriMap = new HashMap<>(); + final List simpleExtensionURNs, + final List simpleExtensionURIs, + final List simpleExtensionDeclarations) { + final Map urnMap = new HashMap<>(); + final Map uriMap = new HashMap<>(); // Handle URN format - for (SimpleExtensionURN extension : simpleExtensionURNs) { + for (final SimpleExtensionURN extension : simpleExtensionURNs) { urnMap.put(extension.getExtensionUrnAnchor(), extension.getUrn()); } // Handle deprecated URI format - for (io.substrait.proto.SimpleExtensionURI extension : simpleExtensionURIs) { + for (final io.substrait.proto.SimpleExtensionURI extension : simpleExtensionURIs) { uriMap.put(extension.getExtensionUriAnchor(), extension.getUri()); } // Add all functions used in plan to the functionMap - for (SimpleExtensionDeclaration extension : simpleExtensionDeclarations) { + for (final SimpleExtensionDeclaration extension : simpleExtensionDeclarations) { if (!extension.hasExtensionFunction()) { continue; } - SimpleExtensionDeclaration.ExtensionFunction func = extension.getExtensionFunction(); - int reference = func.getFunctionAnchor(); - SimpleExtension.FunctionAnchor anchor = resolveFunctionAnchor(func, urnMap, uriMap); + final SimpleExtensionDeclaration.ExtensionFunction func = extension.getExtensionFunction(); + final int reference = func.getFunctionAnchor(); + final SimpleExtension.FunctionAnchor anchor = resolveFunctionAnchor(func, urnMap, uriMap); functionMap.put(reference, anchor); } // Add all types used in plan to the typeMap - for (SimpleExtensionDeclaration extension : simpleExtensionDeclarations) { + for (final SimpleExtensionDeclaration extension : simpleExtensionDeclarations) { if (!extension.hasExtensionType()) { continue; } - SimpleExtensionDeclaration.ExtensionType type = extension.getExtensionType(); - int reference = type.getTypeAnchor(); - SimpleExtension.TypeAnchor anchor = resolveTypeAnchor(type, urnMap, uriMap); + final SimpleExtensionDeclaration.ExtensionType type = extension.getExtensionType(); + final int reference = type.getTypeAnchor(); + final SimpleExtension.TypeAnchor anchor = resolveTypeAnchor(type, urnMap, uriMap); typeMap.put(reference, anchor); } diff --git a/core/src/main/java/io/substrait/extension/ProtoExtensionConverter.java b/core/src/main/java/io/substrait/extension/ProtoExtensionConverter.java index d017a912e..1d38574d8 100644 --- a/core/src/main/java/io/substrait/extension/ProtoExtensionConverter.java +++ b/core/src/main/java/io/substrait/extension/ProtoExtensionConverter.java @@ -41,7 +41,7 @@ public AdvancedExtension fromProto(final io.substrait.proto.@NonNull AdvancedExt * @return the converted {@link AdvancedExtension.Optimization} */ protected AdvancedExtension.Optimization optimizationFromAdvancedExtension( - com.google.protobuf.@NonNull Any any) { + final com.google.protobuf.@NonNull Any any) { throw new UnsupportedOperationException( "missing deserialization logic for AdvancedExtension.Optimization"); } @@ -57,7 +57,7 @@ protected AdvancedExtension.Optimization optimizationFromAdvancedExtension( * @return the converted {@link AdvancedExtension.Enhancement} */ protected AdvancedExtension.Enhancement enhancementFromAdvancedExtension( - com.google.protobuf.@NonNull Any any) { + final com.google.protobuf.@NonNull Any any) { throw new UnsupportedOperationException( "missing deserialization logic for AdvancedExtension.Enhancement"); } diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index 39d7c45e0..2d3efd0db 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -53,7 +53,7 @@ public class SimpleExtension { // `\A` means beginning of input. Using it as a delimiter in a scanner reads in the whole file. private static Pattern READ_WHOLE_FILE = Pattern.compile("\\A"); - private static void validateUrn(String urn) { + private static void validateUrn(final String urn) { if (urn == null || urn.trim().isEmpty()) { throw new IllegalArgumentException("URN cannot be null or empty"); } @@ -63,8 +63,8 @@ private static void validateUrn(String urn) { } } - private static ObjectMapper objectMapper(String urn) { - InjectableValues.Std iv = new InjectableValues.Std(); + private static ObjectMapper objectMapper(final String urn) { + final InjectableValues.Std iv = new InjectableValues.Std(); iv.addValue(URN_LOCATOR_KEY, urn); return new ObjectMapper(new YAMLFactory()) @@ -210,14 +210,14 @@ public interface Anchor { @Value.Immutable public interface FunctionAnchor extends Anchor { - static FunctionAnchor of(String urn, String key) { + static FunctionAnchor of(final String urn, final String key) { return ImmutableSimpleExtension.FunctionAnchor.builder().urn(urn).key(key).build(); } } @Value.Immutable public interface TypeAnchor extends Anchor { - static TypeAnchor of(String urn, String name) { + static TypeAnchor of(final String urn, final String name) { return ImmutableSimpleExtension.TypeAnchor.builder().urn(urn).key(name).build(); } } @@ -300,7 +300,7 @@ public FunctionAnchor getAnchor() { public abstract TypeExpression returnType(); public static String constructKeyFromTypes( - String name, List arguments) { + final String name, final List arguments) { try { return name + ":" @@ -313,7 +313,7 @@ public static String constructKeyFromTypes( } } - public static String constructKey(String name, List arguments) { + public static String constructKey(final String name, final List arguments) { try { return name + ":" @@ -326,12 +326,12 @@ public static String constructKey(String name, List arguments) { public Util.IntRange getRange() { // end range is exclusive so add one to size. - int max = + final int max = variadic() .map( t -> { - OptionalInt optionalMax = t.getMax(); - IntStream stream = + final OptionalInt optionalMax = t.getMax(); + final IntStream stream = optionalMax.isPresent() ? IntStream.of(optionalMax.getAsInt()) : IntStream.empty(); @@ -341,13 +341,13 @@ public Util.IntRange getRange() { .orElse(Integer.MAX_VALUE); }) .orElse(args().size() + 1); - int min = + final int min = variadic().map(t -> args().size() - 1 + t.getMin()).orElse(requiredArguments().size()); return Util.IntRange.of(min, max); } public void validateOutputType( - List argumentExpressions, io.substrait.type.Type outputType) { + final List argumentExpressions, final io.substrait.type.Type outputType) { // TODO: support advanced output type validation using return expressions, parameters, etc. // The code below was too restrictive in the case of nullability conversion. return; @@ -366,7 +366,7 @@ public String key() { return keySupplier.get(); } - public io.substrait.type.Type resolveType(List argumentTypes) { + public io.substrait.type.Type resolveType(final List argumentTypes) { return TypeExpressionEvaluator.evaluateExpression(returnType(), args(), argumentTypes); } } @@ -382,7 +382,7 @@ public abstract static class ScalarFunction { public abstract List impls(); - public Stream resolve(String urn) { + public Stream resolve(final String urn) { return impls().stream().map(f -> f.resolve(urn, name(), description())); } } @@ -391,7 +391,8 @@ public Stream resolve(String urn) { @JsonSerialize(as = ImmutableSimpleExtension.ScalarFunctionVariant.class) @Value.Immutable public abstract static class ScalarFunctionVariant extends Function { - public ScalarFunctionVariant resolve(String urn, String name, String description) { + public ScalarFunctionVariant resolve( + final String urn, final String name, final String description) { return ImmutableSimpleExtension.ScalarFunctionVariant.builder() .urn(urn) .name(name) @@ -418,7 +419,7 @@ public abstract static class AggregateFunction { public abstract List impls(); - public Stream resolve(String urn) { + public Stream resolve(final String urn) { return impls().stream().map(f -> f.resolve(urn, name(), description())); } } @@ -435,7 +436,7 @@ public abstract static class WindowFunction { public abstract List impls(); - public Stream resolve(String urn) { + public Stream resolve(final String urn) { return impls().stream().map(f -> f.resolve(urn, name(), description())); } @@ -462,7 +463,8 @@ public String toString() { @Nullable public abstract TypeExpression intermediate(); - AggregateFunctionVariant resolve(String urn, String name, String description) { + AggregateFunctionVariant resolve( + final String urn, final String name, final String description) { return ImmutableSimpleExtension.AggregateFunctionVariant.builder() .urn(urn) .name(name) @@ -504,7 +506,7 @@ public String toString() { return super.toString(); } - WindowFunctionVariant resolve(String urn, String name, String description) { + WindowFunctionVariant resolve(final String urn, final String name, final String description) { return ImmutableSimpleExtension.WindowFunctionVariant.builder() .urn(urn) .name(name) @@ -580,7 +582,7 @@ public int size() { + (windows() == null ? 0 : windows().size()); } - public Stream resolve(String urn) { + public Stream resolve(final String urn) { return Stream.concat( Stream.concat( scalars() == null ? Stream.of() : scalars().stream().flatMap(f -> f.resolve(urn)), @@ -654,8 +656,8 @@ public static ImmutableSimpleExtension.ExtensionCollection.Builder builder() { return ImmutableSimpleExtension.ExtensionCollection.builder(); } - public Type getType(TypeAnchor anchor) { - Type type = typeLookup.get().get(anchor); + public Type getType(final TypeAnchor anchor) { + final Type type = typeLookup.get().get(anchor); if (type != null) { return type; } @@ -666,8 +668,8 @@ public Type getType(TypeAnchor anchor) { anchor.key(), anchor.urn())); } - public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) { - ScalarFunctionVariant variant = scalarFunctionsLookup.get().get(anchor); + public ScalarFunctionVariant getScalarFunction(final FunctionAnchor anchor) { + final ScalarFunctionVariant variant = scalarFunctionsLookup.get().get(anchor); if (variant != null) { return variant; } @@ -679,7 +681,7 @@ public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) { anchor.key(), anchor.urn())); } - private void checkUrn(String name) { + private void checkUrn(final String name) { if (urnSupplier.get().contains(name)) { return; } @@ -691,8 +693,8 @@ private void checkUrn(String name) { name)); } - public AggregateFunctionVariant getAggregateFunction(FunctionAnchor anchor) { - AggregateFunctionVariant variant = aggregateFunctionsLookup.get().get(anchor); + public AggregateFunctionVariant getAggregateFunction(final FunctionAnchor anchor) { + final AggregateFunctionVariant variant = aggregateFunctionsLookup.get().get(anchor); if (variant != null) { return variant; } @@ -705,8 +707,8 @@ public AggregateFunctionVariant getAggregateFunction(FunctionAnchor anchor) { anchor.key(), anchor.urn())); } - public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { - WindowFunctionVariant variant = windowFunctionsLookup.get().get(anchor); + public WindowFunctionVariant getWindowFunction(final FunctionAnchor anchor) { + final WindowFunctionVariant variant = windowFunctionsLookup.get().get(anchor); if (variant != null) { return variant; } @@ -724,7 +726,7 @@ public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { * @param urn The URN to look up * @return The corresponding URI, or null if not found */ - String getUriFromUrn(String urn) { + String getUriFromUrn(final String urn) { return uriUrnMap().reverseGet(urn); } @@ -734,12 +736,12 @@ String getUriFromUrn(String urn) { * @param uri The URI to look up * @return The corresponding URN, or null if not found */ - String getUrnFromUri(String uri) { + String getUrnFromUri(final String uri) { return uriUrnMap().get(uri); } - public ExtensionCollection merge(ExtensionCollection extensionCollection) { - BidiMap mergedUriUrnMap = new BidiMap<>(); + public ExtensionCollection merge(final ExtensionCollection extensionCollection) { + final BidiMap mergedUriUrnMap = new BidiMap<>(); mergedUriUrnMap.merge(uriUrnMap()); mergedUriUrnMap.merge(extensionCollection.uriUrnMap()); @@ -757,12 +759,12 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) { } } - public static ExtensionCollection load(List resourcePaths) { + public static ExtensionCollection load(final List resourcePaths) { if (resourcePaths.isEmpty()) { throw new IllegalArgumentException("Require at least one resource path."); } - List extensions = + final List extensions = resourcePaths.stream() .map( path -> { @@ -780,26 +782,26 @@ public static ExtensionCollection load(List resourcePaths) { return complete; } - public static ExtensionCollection load(String uri, String content) { + public static ExtensionCollection load(final String uri, final String content) { try { if (uri == null || uri.isEmpty()) { throw new IllegalArgumentException("URI cannot be null or empty"); } // Parse with basic YAML mapper first to extract URN - ObjectMapper basicYamlMapper = new ObjectMapper(new YAMLFactory()); - com.fasterxml.jackson.databind.JsonNode rootNode = basicYamlMapper.readTree(content); - com.fasterxml.jackson.databind.JsonNode urnNode = rootNode.get("urn"); + final ObjectMapper basicYamlMapper = new ObjectMapper(new YAMLFactory()); + final com.fasterxml.jackson.databind.JsonNode rootNode = basicYamlMapper.readTree(content); + final com.fasterxml.jackson.databind.JsonNode urnNode = rootNode.get("urn"); if (urnNode == null) { throw new IllegalArgumentException("Extension YAML file must contain a 'urn' field"); } - String urn = urnNode.asText(); + final String urn = urnNode.asText(); validateUrn(urn); - ExtensionSignatures docWithoutUri = + final ExtensionSignatures docWithoutUri = objectMapper(urn).readValue(content, ExtensionSignatures.class); - ExtensionSignatures doc = + final ExtensionSignatures doc = ImmutableSimpleExtension.ExtensionSignatures.builder().from(docWithoutUri).build(); return buildExtensionCollection(uri, doc); @@ -808,36 +810,36 @@ public static ExtensionCollection load(String uri, String content) { } } - public static ExtensionCollection load(String uri, InputStream stream) { + public static ExtensionCollection load(final String uri, final InputStream stream) { try (Scanner scanner = new Scanner(stream)) { scanner.useDelimiter(READ_WHOLE_FILE); - String content = scanner.next(); + final String content = scanner.next(); return load(uri, content); } } public static ExtensionCollection buildExtensionCollection( - String uri, ExtensionSignatures extensionSignatures) { - String urn = extensionSignatures.urn(); + final String uri, final ExtensionSignatures extensionSignatures) { + final String urn = extensionSignatures.urn(); validateUrn(urn); if (uri == null || uri == "") { throw new IllegalArgumentException("URI cannot be null or empty"); } - List scalarFunctionVariants = + final List scalarFunctionVariants = extensionSignatures.scalars().stream() .flatMap(t -> t.resolve(urn)) .collect(Collectors.toList()); - List aggregateFunctionVariants = + final List aggregateFunctionVariants = extensionSignatures.aggregates().stream() .flatMap(t -> t.resolve(urn)) .collect(Collectors.toList()); - Stream windowFunctionVariants = + final Stream windowFunctionVariants = extensionSignatures.windows().stream().flatMap(t -> t.resolve(urn)); // Aggregate functions can be used as Window Functions - Stream windowAggFunctionVariants = + final Stream windowAggFunctionVariants = aggregateFunctionVariants.stream() .map( afi -> @@ -851,14 +853,14 @@ public static ExtensionCollection buildExtensionCollection( .windowType(SimpleExtension.WindowType.STREAMING) .build()); - List allWindowFunctionVariants = + final List allWindowFunctionVariants = Stream.concat(windowFunctionVariants, windowAggFunctionVariants) .collect(Collectors.toList()); - BidiMap uriUrnMap = new BidiMap<>(); + final BidiMap uriUrnMap = new BidiMap<>(); uriUrnMap.put(uri, urn); - ImmutableSimpleExtension.ExtensionCollection collection = + final ImmutableSimpleExtension.ExtensionCollection collection = ImmutableSimpleExtension.ExtensionCollection.builder() .scalarFunctions(scalarFunctionVariants) .aggregateFunctions(aggregateFunctionVariants) diff --git a/core/src/main/java/io/substrait/function/ParameterizedType.java b/core/src/main/java/io/substrait/function/ParameterizedType.java index e514fb975..b5aa7b4f7 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedType.java +++ b/core/src/main/java/io/substrait/function/ParameterizedType.java @@ -21,9 +21,9 @@ public synchronized Throwable fillInStackTrace() { } @Override - R accept(final TypeVisitor typeVisitor) throws E; + R accept(TypeVisitor typeVisitor) throws E; - static ParameterizedTypeCreator withNullability(boolean nullable) { + static ParameterizedTypeCreator withNullability(final boolean nullable) { return nullable ? ParameterizedTypeCreator.NULLABLE : ParameterizedTypeCreator.REQUIRED; } @@ -45,7 +45,7 @@ public final R accept(final TypeVisitor typeVisit } abstract R accept( - final ParameterizedTypeVisitor parameterizedTypeVisitor) throws E; + ParameterizedTypeVisitor parameterizedTypeVisitor) throws E; } @Value.Immutable diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java index 6b89840f6..80ad0d743 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java @@ -8,20 +8,21 @@ public class ParameterizedTypeCreator extends TypeCreator public static final ParameterizedTypeCreator REQUIRED = new ParameterizedTypeCreator(false); public static final ParameterizedTypeCreator NULLABLE = new ParameterizedTypeCreator(true); - protected ParameterizedTypeCreator(boolean nullable) { + protected ParameterizedTypeCreator(final boolean nullable) { super(nullable); } - private static ParameterizedType.StringLiteral parameter(String literal, boolean nullable) { + private static ParameterizedType.StringLiteral parameter( + final String literal, final boolean nullable) { return ParameterizedType.StringLiteral.builder().nullable(nullable).value(literal).build(); } - public ParameterizedType.StringLiteral parameter(String literal) { + public ParameterizedType.StringLiteral parameter(final String literal) { return parameter(literal, nullable); } @Override - public ParameterizedType fixedCharE(String len) { + public ParameterizedType fixedCharE(final String len) { return ParameterizedType.FixedChar.builder() .nullable(nullable) .length(parameter(len, false)) @@ -29,7 +30,7 @@ public ParameterizedType fixedCharE(String len) { } @Override - public ParameterizedType varCharE(String len) { + public ParameterizedType varCharE(final String len) { return ParameterizedType.VarChar.builder() .nullable(nullable) .length(parameter(len, false)) @@ -37,7 +38,7 @@ public ParameterizedType varCharE(String len) { } @Override - public ParameterizedType fixedBinaryE(String len) { + public ParameterizedType fixedBinaryE(final String len) { return ParameterizedType.FixedBinary.builder() .nullable(nullable) .length(parameter(len, false)) @@ -45,7 +46,7 @@ public ParameterizedType fixedBinaryE(String len) { } @Override - public ParameterizedType decimalE(String precision, String scale) { + public ParameterizedType decimalE(final String precision, final String scale) { return ParameterizedType.Decimal.builder() .nullable(nullable) .precision(parameter(precision, false)) @@ -53,28 +54,28 @@ public ParameterizedType decimalE(String precision, String scale) { .build(); } - public ParameterizedType intervalDayE(String precision) { + public ParameterizedType intervalDayE(final String precision) { return ParameterizedType.IntervalDay.builder() .nullable(nullable) .precision(parameter(precision, false)) .build(); } - public ParameterizedType intervalCompoundE(String precision) { + public ParameterizedType intervalCompoundE(final String precision) { return ParameterizedType.IntervalCompound.builder() .nullable(nullable) .precision(parameter(precision, false)) .build(); } - public ParameterizedType precisionTimestampE(String precision) { + public ParameterizedType precisionTimestampE(final String precision) { return ParameterizedType.PrecisionTimestamp.builder() .nullable(nullable) .precision(parameter(precision, false)) .build(); } - public ParameterizedType precisionTimestampTZE(String precision) { + public ParameterizedType precisionTimestampTZE(final String precision) { return ParameterizedType.PrecisionTimestampTZ.builder() .nullable(nullable) .precision(parameter(precision, false)) @@ -82,22 +83,22 @@ public ParameterizedType precisionTimestampTZE(String precision) { } @Override - public ParameterizedType structE(ParameterizedType... types) { + public ParameterizedType structE(final ParameterizedType... types) { return ParameterizedType.Struct.builder().nullable(nullable).addFields(types).build(); } @Override - public ParameterizedType structE(Iterable types) { + public ParameterizedType structE(final Iterable types) { return ParameterizedType.Struct.builder().nullable(nullable).addAllFields(types).build(); } @Override - public ParameterizedType listE(ParameterizedType type) { + public ParameterizedType listE(final ParameterizedType type) { return ParameterizedType.ListType.builder().nullable(nullable).name(type).build(); } @Override - public ParameterizedType mapE(ParameterizedType key, ParameterizedType value) { + public ParameterizedType mapE(final ParameterizedType key, final ParameterizedType value) { return ParameterizedType.Map.builder().nullable(nullable).key(key).value(value).build(); } } diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java index 9ff42f549..5e04b2422 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java @@ -32,72 +32,72 @@ public interface ParameterizedTypeVisitor extends TypeVi abstract class ParameterizedTypeThrowsVisitor extends TypeVisitor.TypeThrowsVisitor implements ParameterizedTypeVisitor { - protected ParameterizedTypeThrowsVisitor(String unsupportedMessage) { + protected ParameterizedTypeThrowsVisitor(final String unsupportedMessage) { super(unsupportedMessage); } @Override - public R visit(ParameterizedType.FixedChar expr) throws E { + public R visit(final ParameterizedType.FixedChar expr) throws E { throw t(); } @Override - public R visit(ParameterizedType.VarChar expr) throws E { + public R visit(final ParameterizedType.VarChar expr) throws E { throw t(); } @Override - public R visit(ParameterizedType.FixedBinary expr) throws E { + public R visit(final ParameterizedType.FixedBinary expr) throws E { throw t(); } @Override - public R visit(ParameterizedType.Decimal expr) throws E { + public R visit(final ParameterizedType.Decimal expr) throws E { throw t(); } @Override - public R visit(ParameterizedType.PrecisionTime expr) throws E { + public R visit(final ParameterizedType.PrecisionTime expr) throws E { throw t(); } @Override - public R visit(ParameterizedType.PrecisionTimestamp expr) throws E { + public R visit(final ParameterizedType.PrecisionTimestamp expr) throws E { throw t(); } @Override - public R visit(ParameterizedType.PrecisionTimestampTZ expr) throws E { + public R visit(final ParameterizedType.PrecisionTimestampTZ expr) throws E { throw t(); } @Override - public R visit(ParameterizedType.IntervalDay expr) throws E { + public R visit(final ParameterizedType.IntervalDay expr) throws E { throw t(); } @Override - public R visit(ParameterizedType.IntervalCompound expr) throws E { + public R visit(final ParameterizedType.IntervalCompound expr) throws E { throw t(); } @Override - public R visit(ParameterizedType.Struct expr) throws E { + public R visit(final ParameterizedType.Struct expr) throws E { throw t(); } @Override - public R visit(ParameterizedType.ListType expr) throws E { + public R visit(final ParameterizedType.ListType expr) throws E { throw t(); } @Override - public R visit(ParameterizedType.Map expr) throws E { + public R visit(final ParameterizedType.Map expr) throws E { throw t(); } @Override - public R visit(ParameterizedType.StringLiteral stringLiteral) throws E { + public R visit(final ParameterizedType.StringLiteral stringLiteral) throws E { throw t(); } } diff --git a/core/src/main/java/io/substrait/function/ToTypeString.java b/core/src/main/java/io/substrait/function/ToTypeString.java index d6fc1bdb8..26dbda99b 100644 --- a/core/src/main/java/io/substrait/function/ToTypeString.java +++ b/core/src/main/java/io/substrait/function/ToTypeString.java @@ -7,7 +7,7 @@ public class ToTypeString public static final ToTypeString INSTANCE = new ToTypeString(); - public static String apply(Type type) { + public static String apply(final Type type) { return type.accept(INSTANCE); } @@ -156,62 +156,62 @@ public String visit(final Type.UserDefined expr) { } @Override - public String visit(ParameterizedType.FixedChar expr) throws RuntimeException { + public String visit(final ParameterizedType.FixedChar expr) throws RuntimeException { return "fchar"; } @Override - public String visit(ParameterizedType.VarChar expr) throws RuntimeException { + public String visit(final ParameterizedType.VarChar expr) throws RuntimeException { return "vchar"; } @Override - public String visit(ParameterizedType.FixedBinary expr) throws RuntimeException { + public String visit(final ParameterizedType.FixedBinary expr) throws RuntimeException { return "fbinary"; } @Override - public String visit(ParameterizedType.Decimal expr) throws RuntimeException { + public String visit(final ParameterizedType.Decimal expr) throws RuntimeException { return "dec"; } @Override - public String visit(ParameterizedType.IntervalDay expr) throws RuntimeException { + public String visit(final ParameterizedType.IntervalDay expr) throws RuntimeException { return "iday"; } @Override - public String visit(ParameterizedType.IntervalCompound expr) throws RuntimeException { + public String visit(final ParameterizedType.IntervalCompound expr) throws RuntimeException { return "icompound"; } @Override - public String visit(ParameterizedType.PrecisionTimestamp expr) throws RuntimeException { + public String visit(final ParameterizedType.PrecisionTimestamp expr) throws RuntimeException { return "pts"; } @Override - public String visit(ParameterizedType.PrecisionTimestampTZ expr) throws RuntimeException { + public String visit(final ParameterizedType.PrecisionTimestampTZ expr) throws RuntimeException { return "ptstz"; } @Override - public String visit(ParameterizedType.Struct expr) throws RuntimeException { + public String visit(final ParameterizedType.Struct expr) throws RuntimeException { return "struct"; } @Override - public String visit(ParameterizedType.ListType expr) throws RuntimeException { + public String visit(final ParameterizedType.ListType expr) throws RuntimeException { return "list"; } @Override - public String visit(ParameterizedType.Map expr) throws RuntimeException { + public String visit(final ParameterizedType.Map expr) throws RuntimeException { return "map"; } @Override - public String visit(ParameterizedType.StringLiteral expr) throws RuntimeException { + public String visit(final ParameterizedType.StringLiteral expr) throws RuntimeException { if (expr.value().toLowerCase().startsWith("any")) { return "any"; } else { @@ -233,7 +233,7 @@ public static class ToTypeLiteralStringLossless extends ToTypeString { private ToTypeLiteralStringLossless() {} @Override - public String visit(ParameterizedType.StringLiteral expr) throws RuntimeException { + public String visit(final ParameterizedType.StringLiteral expr) throws RuntimeException { return expr.value().toLowerCase(); } } diff --git a/core/src/main/java/io/substrait/function/TypeExpression.java b/core/src/main/java/io/substrait/function/TypeExpression.java index a183c1959..ddb93c184 100644 --- a/core/src/main/java/io/substrait/function/TypeExpression.java +++ b/core/src/main/java/io/substrait/function/TypeExpression.java @@ -10,9 +10,9 @@ class RequiredTypeExpressionVisitorException extends RuntimeException { private static final long serialVersionUID = 8381558691397737963L; } - R accept(final TypeVisitor typeVisitor) throws E; + R accept(TypeVisitor typeVisitor) throws E; - static TypeExpressionCreator withNullability(boolean nullable) { + static TypeExpressionCreator withNullability(final boolean nullable) { return nullable ? TypeExpressionCreator.NULLABLE : TypeExpressionCreator.REQUIRED; } @@ -26,7 +26,7 @@ public final R accept(final TypeVisitor typeVisit } abstract R acceptE( - final TypeExpressionVisitor parameterizedTypeVisitor) throws E; + TypeExpressionVisitor parameterizedTypeVisitor) throws E; } @Value.Immutable diff --git a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java index b7524911b..a31ea01c8 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java @@ -9,27 +9,27 @@ public class TypeExpressionCreator extends TypeCreator public static final TypeExpressionCreator REQUIRED = new TypeExpressionCreator(false); public static final TypeExpressionCreator NULLABLE = new TypeExpressionCreator(true); - protected TypeExpressionCreator(boolean nullable) { + protected TypeExpressionCreator(final boolean nullable) { super(nullable); } @Override - public TypeExpression fixedCharE(TypeExpression len) { + public TypeExpression fixedCharE(final TypeExpression len) { return TypeExpression.FixedChar.builder().nullable(nullable).length(len).build(); } @Override - public TypeExpression varCharE(TypeExpression len) { + public TypeExpression varCharE(final TypeExpression len) { return TypeExpression.VarChar.builder().nullable(nullable).length(len).build(); } @Override - public TypeExpression fixedBinaryE(TypeExpression len) { + public TypeExpression fixedBinaryE(final TypeExpression len) { return TypeExpression.FixedBinary.builder().nullable(nullable).length(len).build(); } @Override - public TypeExpression decimalE(TypeExpression precision, TypeExpression scale) { + public TypeExpression decimalE(final TypeExpression precision, final TypeExpression scale) { return TypeExpression.Decimal.builder() .nullable(nullable) .scale(scale) @@ -37,25 +37,25 @@ public TypeExpression decimalE(TypeExpression precision, TypeExpression scale) { .build(); } - public TypeExpression intervalDayE(TypeExpression precision) { + public TypeExpression intervalDayE(final TypeExpression precision) { return TypeExpression.IntervalDay.builder().nullable(nullable).precision(precision).build(); } - public TypeExpression intervalCompoundE(TypeExpression precision) { + public TypeExpression intervalCompoundE(final TypeExpression precision) { return TypeExpression.IntervalCompound.builder() .nullable(nullable) .precision(precision) .build(); } - public TypeExpression precisionTimestampE(TypeExpression precision) { + public TypeExpression precisionTimestampE(final TypeExpression precision) { return TypeExpression.PrecisionTimestamp.builder() .nullable(nullable) .precision(precision) .build(); } - public TypeExpression precisionTimestampTZE(TypeExpression precision) { + public TypeExpression precisionTimestampTZE(final TypeExpression precision) { return TypeExpression.PrecisionTimestampTZ.builder() .nullable(nullable) .precision(precision) @@ -63,22 +63,22 @@ public TypeExpression precisionTimestampTZE(TypeExpression precision) { } @Override - public TypeExpression structE(TypeExpression... types) { + public TypeExpression structE(final TypeExpression... types) { return TypeExpression.Struct.builder().nullable(nullable).addFields(types).build(); } @Override - public TypeExpression structE(Iterable types) { + public TypeExpression structE(final Iterable types) { return TypeExpression.Struct.builder().nullable(nullable).addAllFields(types).build(); } @Override - public TypeExpression listE(TypeExpression type) { + public TypeExpression listE(final TypeExpression type) { return TypeExpression.ListType.builder().nullable(nullable).elementType(type).build(); } @Override - public TypeExpression mapE(TypeExpression key, TypeExpression value) { + public TypeExpression mapE(final TypeExpression key, final TypeExpression value) { return TypeExpression.Map.builder().nullable(nullable).key(key).value(value).build(); } @@ -103,7 +103,8 @@ public TypeExpression expr() { } ; - public static TypeExpression program(TypeExpression finalExpr, Assign... assignments) { + public static TypeExpression program( + final TypeExpression finalExpr, final Assign... assignments) { return TypeExpression.ReturnProgram.builder() .finalExpression(finalExpr) .addAllAssignments( @@ -118,20 +119,22 @@ public static TypeExpression program(TypeExpression finalExpr, Assign... assignm .build(); } - public static TypeExpression plus(TypeExpression left, TypeExpression right) { + public static TypeExpression plus(final TypeExpression left, final TypeExpression right) { return binary(TypeExpression.BinaryOperation.OpType.ADD, left, right); } - public static TypeExpression minus(TypeExpression left, TypeExpression right) { + public static TypeExpression minus(final TypeExpression left, final TypeExpression right) { return binary(TypeExpression.BinaryOperation.OpType.SUBTRACT, left, right); } public static TypeExpression binary( - TypeExpression.BinaryOperation.OpType op, TypeExpression left, TypeExpression right) { + final TypeExpression.BinaryOperation.OpType op, + final TypeExpression left, + final TypeExpression right) { return TypeExpression.BinaryOperation.builder().opType(op).left(left).right(right).build(); } - public static TypeExpression.IntegerLiteral i(int i) { + public static TypeExpression.IntegerLiteral i(final int i) { return TypeExpression.IntegerLiteral.builder().value(i).build(); } } diff --git a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java index 31d632c71..05b39111a 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java @@ -38,87 +38,87 @@ abstract class TypeExpressionThrowsVisitor extends ParameterizedTypeVisitor.ParameterizedTypeThrowsVisitor implements TypeExpressionVisitor { - protected TypeExpressionThrowsVisitor(String unsupportedMessage) { + protected TypeExpressionThrowsVisitor(final String unsupportedMessage) { super(unsupportedMessage); } @Override - public R visit(TypeExpression.FixedChar expr) throws E { + public R visit(final TypeExpression.FixedChar expr) throws E { throw t(); } @Override - public R visit(TypeExpression.VarChar expr) throws E { + public R visit(final TypeExpression.VarChar expr) throws E { throw t(); } @Override - public R visit(TypeExpression.FixedBinary expr) throws E { + public R visit(final TypeExpression.FixedBinary expr) throws E { throw t(); } @Override - public R visit(TypeExpression.Decimal expr) throws E { + public R visit(final TypeExpression.Decimal expr) throws E { throw t(); } @Override - public R visit(TypeExpression.PrecisionTimestamp expr) throws E { + public R visit(final TypeExpression.PrecisionTimestamp expr) throws E { throw t(); } @Override - public R visit(TypeExpression.PrecisionTimestampTZ expr) throws E { + public R visit(final TypeExpression.PrecisionTimestampTZ expr) throws E { throw t(); } @Override - public R visit(TypeExpression.IntervalDay expr) throws E { + public R visit(final TypeExpression.IntervalDay expr) throws E { throw t(); } @Override - public R visit(TypeExpression.IntervalCompound expr) throws E { + public R visit(final TypeExpression.IntervalCompound expr) throws E { throw t(); } @Override - public R visit(TypeExpression.Struct expr) throws E { + public R visit(final TypeExpression.Struct expr) throws E { throw t(); } @Override - public R visit(TypeExpression.ListType expr) throws E { + public R visit(final TypeExpression.ListType expr) throws E { throw t(); } @Override - public R visit(TypeExpression.Map expr) throws E { + public R visit(final TypeExpression.Map expr) throws E { throw t(); } @Override - public R visit(TypeExpression.BinaryOperation expr) throws E { + public R visit(final TypeExpression.BinaryOperation expr) throws E { throw t(); } @Override - public R visit(TypeExpression.NotOperation expr) throws E { + public R visit(final TypeExpression.NotOperation expr) throws E { throw t(); } @Override - public R visit(TypeExpression.IfOperation expr) throws E { + public R visit(final TypeExpression.IfOperation expr) throws E { throw t(); } @Override - public R visit(TypeExpression.IntegerLiteral expr) throws E { + public R visit(final TypeExpression.IntegerLiteral expr) throws E { throw t(); } @Override - public R visit(TypeExpression.ReturnProgram expr) throws E { + public R visit(final TypeExpression.ReturnProgram expr) throws E { throw t(); } } diff --git a/core/src/main/java/io/substrait/hint/Hint.java b/core/src/main/java/io/substrait/hint/Hint.java index 580110840..573c49ca4 100644 --- a/core/src/main/java/io/substrait/hint/Hint.java +++ b/core/src/main/java/io/substrait/hint/Hint.java @@ -28,7 +28,7 @@ public enum ComputationType { private final RelCommon.Hint.ComputationType proto; - ComputationType(RelCommon.Hint.ComputationType compType) { + ComputationType(final RelCommon.Hint.ComputationType compType) { this.proto = compType; } @@ -36,7 +36,7 @@ public RelCommon.Hint.ComputationType toProto() { return this.proto; } - public static ComputationType fromProto(RelCommon.Hint.ComputationType proto) { + public static ComputationType fromProto(final RelCommon.Hint.ComputationType proto) { for (final ComputationType compTypePojo : values()) { if (compTypePojo.proto == proto) { return compTypePojo; diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index 12501c1ac..e5eec543b 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -66,18 +66,18 @@ protected ProtoRelConverter getProtoRelConverter(final ExtensionLookup functionL return new ProtoRelConverter(functionLookup, this.extensionCollection, protoExtensionConverter); } - public Plan from(io.substrait.proto.Plan plan) { - ExtensionLookup functionLookup = + public Plan from(final io.substrait.proto.Plan plan) { + final ExtensionLookup functionLookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); - ProtoRelConverter relConverter = getProtoRelConverter(functionLookup); - List roots = new ArrayList<>(); - for (PlanRel planRel : plan.getRelationsList()) { - io.substrait.proto.RelRoot root = planRel.getRoot(); - Rel rel = relConverter.from(root.getInput()); + final ProtoRelConverter relConverter = getProtoRelConverter(functionLookup); + final List roots = new ArrayList<>(); + for (final PlanRel planRel : plan.getRelationsList()) { + final io.substrait.proto.RelRoot root = planRel.getRoot(); + final Rel rel = relConverter.from(root.getInput()); roots.add(Plan.Root.builder().input(rel).names(root.getNamesList()).build()); } - ImmutableVersion.Builder versionBuilder = + final ImmutableVersion.Builder versionBuilder = ImmutableVersion.builder() .major(plan.getVersion().getMajorNumber()) .minor(plan.getVersion().getMinorNumber()) diff --git a/core/src/main/java/io/substrait/relation/AbstractDdlRel.java b/core/src/main/java/io/substrait/relation/AbstractDdlRel.java index 54f2504bc..5c9682d81 100644 --- a/core/src/main/java/io/substrait/relation/AbstractDdlRel.java +++ b/core/src/main/java/io/substrait/relation/AbstractDdlRel.java @@ -24,7 +24,7 @@ public enum DdlObject { private final DdlRel.DdlObject proto; - DdlObject(DdlRel.DdlObject proto) { + DdlObject(final DdlRel.DdlObject proto) { this.proto = proto; } @@ -32,8 +32,8 @@ public DdlRel.DdlObject toProto() { return proto; } - public static DdlObject fromProto(DdlRel.DdlObject proto) { - for (DdlObject v : values()) { + public static DdlObject fromProto(final DdlRel.DdlObject proto) { + for (final DdlObject v : values()) { if (v.proto == proto) { return v; } @@ -52,7 +52,7 @@ public enum DdlOp { private final DdlRel.DdlOp proto; - DdlOp(DdlRel.DdlOp proto) { + DdlOp(final DdlRel.DdlOp proto) { this.proto = proto; } @@ -60,8 +60,8 @@ public DdlRel.DdlOp toProto() { return proto; } - public static DdlOp fromProto(DdlRel.DdlOp proto) { - for (DdlOp v : values()) { + public static DdlOp fromProto(final DdlRel.DdlOp proto) { + for (final DdlOp v : values()) { if (v.proto == proto) { return v; } diff --git a/core/src/main/java/io/substrait/relation/AbstractRel.java b/core/src/main/java/io/substrait/relation/AbstractRel.java index db6df00f7..30a8a1ca7 100644 --- a/core/src/main/java/io/substrait/relation/AbstractRel.java +++ b/core/src/main/java/io/substrait/relation/AbstractRel.java @@ -9,7 +9,7 @@ public abstract class AbstractRel implements Rel { private Supplier recordType = Util.memoize( () -> { - Type.Struct s = deriveRecordType(); + final Type.Struct s = deriveRecordType(); return getRemap().map(r -> r.remap(s)).orElse(s); }); diff --git a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java index 02418ce1d..b77b4a2d1 100644 --- a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java +++ b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java @@ -15,157 +15,158 @@ public abstract class AbstractRelVisitor O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java index 08776ed5d..66e7866dc 100644 --- a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java @@ -22,17 +22,17 @@ public class AggregateFunctionProtoConverter { private final TypeProtoConverter typeProtoConverter; private final ExtensionCollector functionCollector; - public AggregateFunctionProtoConverter(ExtensionCollector functionCollector) { + public AggregateFunctionProtoConverter(final ExtensionCollector functionCollector) { this.functionCollector = functionCollector; this.exprProtoConverter = new ExpressionProtoConverter(functionCollector, null); this.typeProtoConverter = new TypeProtoConverter(functionCollector); } - public AggregateFunction toProto(Aggregate.Measure measure) { - FunctionArg.FuncArgVisitor + public AggregateFunction toProto(final Aggregate.Measure measure) { + final FunctionArg.FuncArgVisitor argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); - List args = measure.getFunction().arguments(); - SimpleExtension.AggregateFunctionVariant aggFuncDef = measure.getFunction().declaration(); + final List args = measure.getFunction().arguments(); + final SimpleExtension.AggregateFunctionVariant aggFuncDef = measure.getFunction().declaration(); return AggregateFunction.newBuilder() .setPhase(measure.getFunction().aggregationPhase().toProto()) diff --git a/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java b/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java index 1c736d34c..0eb1aa05d 100644 --- a/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java +++ b/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java @@ -25,7 +25,7 @@ public abstract class ConsistentPartitionWindow extends SingleInputRel implement @Override protected Type.Struct deriveRecordType() { - Type.Struct initial = getInput().getRecordType(); + final Type.Struct initial = getInput().getRecordType(); return TypeCreator.of(initial.nullable()) .struct( Stream.concat( @@ -35,7 +35,7 @@ protected Type.Struct deriveRecordType() { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/CopyOnWriteUtils.java b/core/src/main/java/io/substrait/relation/CopyOnWriteUtils.java index e470c526b..74ef98bff 100644 --- a/core/src/main/java/io/substrait/relation/CopyOnWriteUtils.java +++ b/core/src/main/java/io/substrait/relation/CopyOnWriteUtils.java @@ -10,12 +10,13 @@ /** Provides common utilities for copy-on-write visitations */ public class CopyOnWriteUtils { - public static boolean allEmpty(Optional... optionals) { + public static boolean allEmpty(final Optional... optionals) { return Arrays.stream(optionals).noneMatch(Optional::isPresent); } /** The `or` method on Optional instances is a Java 9+ feature */ - public static Optional or(Optional left, Supplier> right) { + public static Optional or( + final Optional left, final Supplier> right) { if (left.isPresent()) { return left; } else { @@ -44,11 +45,12 @@ public interface TransformFunction Optional> transformList( - List items, C context, TransformFunction transform) throws E { - List newItems = new ArrayList<>(); + final List items, final C context, final TransformFunction transform) + throws E { + final List newItems = new ArrayList<>(); boolean listUpdated = false; - for (I item : items) { - Optional newItem = transform.apply(item, context); + for (final I item : items) { + final Optional newItem = transform.apply(item, context); if (newItem.isPresent()) { newItems.add(newItem.get()); listUpdated = true; diff --git a/core/src/main/java/io/substrait/relation/Cross.java b/core/src/main/java/io/substrait/relation/Cross.java index b6ab4b42f..bcddd8150 100644 --- a/core/src/main/java/io/substrait/relation/Cross.java +++ b/core/src/main/java/io/substrait/relation/Cross.java @@ -19,7 +19,7 @@ protected Type.Struct deriveRecordType() { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/EmptyScan.java b/core/src/main/java/io/substrait/relation/EmptyScan.java index 95d304b49..015b5df65 100644 --- a/core/src/main/java/io/substrait/relation/EmptyScan.java +++ b/core/src/main/java/io/substrait/relation/EmptyScan.java @@ -8,7 +8,7 @@ public abstract class EmptyScan extends AbstractReadRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/Expand.java b/core/src/main/java/io/substrait/relation/Expand.java index 67bed0220..1304b9139 100644 --- a/core/src/main/java/io/substrait/relation/Expand.java +++ b/core/src/main/java/io/substrait/relation/Expand.java @@ -15,14 +15,14 @@ public abstract class Expand extends SingleInputRel { @Override public Type.Struct deriveRecordType() { - Type.Struct initial = getInput().getRecordType(); + final Type.Struct initial = getInput().getRecordType(); return TypeCreator.of(initial.nullable()) .struct(getFields().stream().map(ExpandField::getType)); } @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } @@ -54,8 +54,8 @@ public abstract static class SwitchingField implements ExpandField { @Override public Type getType() { - boolean nullable = getDuplicates().stream().anyMatch(d -> d.getType().nullable()); - Type type = getDuplicates().get(0).getType(); + final boolean nullable = getDuplicates().stream().anyMatch(d -> d.getType().nullable()); + final Type type = getDuplicates().get(0).getType(); return nullable ? TypeCreator.asNullable(type) : TypeCreator.asNotNullable(type); } diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 57132a940..20cfe9f65 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -16,7 +16,7 @@ public class ExpressionCopyOnWriteVisitor private final RelCopyOnWriteVisitor relCopyOnWriteVisitor; - public ExpressionCopyOnWriteVisitor(RelCopyOnWriteVisitor relCopyOnWriteVisitor) { + public ExpressionCopyOnWriteVisitor(final RelCopyOnWriteVisitor relCopyOnWriteVisitor) { this.relCopyOnWriteVisitor = relCopyOnWriteVisitor; } @@ -25,197 +25,200 @@ protected final RelCopyOnWriteVisitor getRelCopyOnWriteVisitor() { } /** Utility method for visiting literals. By default, visits to literal types call this. */ - public Optional visitLiteral(Expression.Literal literal) { + public Optional visitLiteral(final Expression.Literal literal) { return Optional.empty(); } @Override - public Optional visit(Expression.NullLiteral expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.NullLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.BoolLiteral expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.BoolLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.I8Literal expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.I8Literal expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.I16Literal expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.I16Literal expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.I32Literal expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.I32Literal expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.I64Literal expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.I64Literal expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.FP32Literal expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.FP32Literal expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.FP64Literal expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.FP64Literal expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.StrLiteral expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.StrLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.BinaryLiteral expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.BinaryLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.TimeLiteral expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.TimeLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.DateLiteral expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.DateLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override public Optional visit( - Expression.TimestampLiteral expr, EmptyVisitationContext context) throws E { + final Expression.TimestampLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override public Optional visit( - Expression.TimestampTZLiteral expr, EmptyVisitationContext context) throws E { + final Expression.TimestampTZLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override public Optional visit( - Expression.PrecisionTimestampLiteral expr, EmptyVisitationContext context) throws E { + final Expression.PrecisionTimestampLiteral expr, final EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override public Optional visit( - Expression.PrecisionTimestampTZLiteral expr, EmptyVisitationContext context) throws E { + final Expression.PrecisionTimestampTZLiteral expr, final EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override public Optional visit( - Expression.IntervalYearLiteral expr, EmptyVisitationContext context) throws E { + final Expression.IntervalYearLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override public Optional visit( - Expression.IntervalDayLiteral expr, EmptyVisitationContext context) throws E { + final Expression.IntervalDayLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override public Optional visit( - Expression.IntervalCompoundLiteral expr, EmptyVisitationContext context) throws E { + final Expression.IntervalCompoundLiteral expr, final EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.UUIDLiteral expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.UUIDLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override public Optional visit( - Expression.FixedCharLiteral expr, EmptyVisitationContext context) throws E { + final Expression.FixedCharLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.VarCharLiteral expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.VarCharLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override public Optional visit( - Expression.FixedBinaryLiteral expr, EmptyVisitationContext context) throws E { + final Expression.FixedBinaryLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.DecimalLiteral expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.DecimalLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.MapLiteral expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.MapLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.EmptyMapLiteral expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.EmptyMapLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.ListLiteral expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.ListLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override public Optional visit( - Expression.EmptyListLiteral expr, EmptyVisitationContext context) throws E { + final Expression.EmptyListLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.StructLiteral expr, EmptyVisitationContext context) - throws E { + public Optional visit( + final Expression.StructLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override public Optional visit( - Expression.UserDefinedLiteral expr, EmptyVisitationContext context) throws E { + final Expression.UserDefinedLiteral expr, final EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.Switch expr, EmptyVisitationContext context) - throws E { - Optional match = expr.match().accept(this, context); - Optional> switchClauses = + public Optional visit( + final Expression.Switch expr, final EmptyVisitationContext context) throws E { + final Optional match = expr.match().accept(this, context); + final Optional> switchClauses = transformList(expr.switchClauses(), context, this::visitSwitchClause); - Optional defaultClause = expr.defaultClause().accept(this, context); + final Optional defaultClause = expr.defaultClause().accept(this, context); if (allEmpty(match, switchClauses, defaultClause)) { return Optional.empty(); @@ -230,7 +233,7 @@ public Optional visit(Expression.Switch expr, EmptyVisitationContext } protected Optional visitSwitchClause( - Expression.SwitchClause switchClause, EmptyVisitationContext context) throws E { + final Expression.SwitchClause switchClause, final EmptyVisitationContext context) throws E { // This code does not visit the condition on the switch clause as that MUST be a Literal and the // visitor does not guarantee a Literal return type. If you wish to update the condition, // override this method. @@ -241,11 +244,11 @@ protected Optional visitSwitchClause( } @Override - public Optional visit(Expression.IfThen ifThen, EmptyVisitationContext context) - throws E { - Optional> ifClauses = + public Optional visit( + final Expression.IfThen ifThen, final EmptyVisitationContext context) throws E { + final Optional> ifClauses = transformList(ifThen.ifClauses(), context, this::visitIfClause); - Optional elseClause = ifThen.elseClause().accept(this, context); + final Optional elseClause = ifThen.elseClause().accept(this, context); if (allEmpty(ifClauses, elseClause)) { return Optional.empty(); @@ -259,9 +262,9 @@ public Optional visit(Expression.IfThen ifThen, EmptyVisitationConte } protected Optional visitIfClause( - Expression.IfClause ifClause, EmptyVisitationContext context) throws E { - Optional condition = ifClause.condition().accept(this, context); - Optional then = ifClause.then().accept(this, context); + final Expression.IfClause ifClause, final EmptyVisitationContext context) throws E { + final Optional condition = ifClause.condition().accept(this, context); + final Optional then = ifClause.then().accept(this, context); if (allEmpty(condition, then)) { return Optional.empty(); @@ -276,7 +279,8 @@ protected Optional visitIfClause( @Override public Optional visit( - Expression.ScalarFunctionInvocation sfi, EmptyVisitationContext context) throws E { + final Expression.ScalarFunctionInvocation sfi, final EmptyVisitationContext context) + throws E { return visitFunctionArguments(sfi.arguments(), context) .map( arguments -> @@ -288,10 +292,11 @@ public Optional visit( @Override public Optional visit( - Expression.WindowFunctionInvocation wfi, EmptyVisitationContext context) throws E { - Optional> arguments = visitFunctionArguments(wfi.arguments(), context); - Optional> partitionBy = visitExprList(wfi.partitionBy(), context); - Optional> sort = + final Expression.WindowFunctionInvocation wfi, final EmptyVisitationContext context) + throws E { + final Optional> arguments = visitFunctionArguments(wfi.arguments(), context); + final Optional> partitionBy = visitExprList(wfi.partitionBy(), context); + final Optional> sort = transformList(wfi.sort(), context, this::visitSortField); if (allEmpty(arguments, partitionBy, sort)) { @@ -307,7 +312,8 @@ public Optional visit( } @Override - public Optional visit(Expression.Cast cast, EmptyVisitationContext context) throws E { + public Optional visit( + final Expression.Cast cast, final EmptyVisitationContext context) throws E { return cast.input() .accept(this, context) .map(input -> Expression.Cast.builder().from(cast).input(input).build()); @@ -315,9 +321,9 @@ public Optional visit(Expression.Cast cast, EmptyVisitationContext c @Override public Optional visit( - Expression.SingleOrList singleOrList, EmptyVisitationContext context) throws E { - Optional condition = singleOrList.condition().accept(this, context); - Optional> options = visitExprList(singleOrList.options(), context); + final Expression.SingleOrList singleOrList, final EmptyVisitationContext context) throws E { + final Optional condition = singleOrList.condition().accept(this, context); + final Optional> options = visitExprList(singleOrList.options(), context); if (allEmpty(condition, options)) { return Optional.empty(); @@ -332,9 +338,9 @@ public Optional visit( @Override public Optional visit( - Expression.MultiOrList multiOrList, EmptyVisitationContext context) throws E { - Optional> conditions = visitExprList(multiOrList.conditions(), context); - Optional> optionCombinations = + final Expression.MultiOrList multiOrList, final EmptyVisitationContext context) throws E { + final Optional> conditions = visitExprList(multiOrList.conditions(), context); + final Optional> optionCombinations = transformList(multiOrList.optionCombinations(), context, this::visitMultiOrListRecord); if (allEmpty(conditions, optionCombinations)) { @@ -349,7 +355,8 @@ public Optional visit( } protected Optional visitMultiOrListRecord( - Expression.MultiOrListRecord multiOrListRecord, EmptyVisitationContext context) throws E { + final Expression.MultiOrListRecord multiOrListRecord, final EmptyVisitationContext context) + throws E { return visitExprList(multiOrListRecord.values(), context) .map( values -> @@ -360,9 +367,9 @@ protected Optional visitMultiOrListRecord( } @Override - public Optional visit(FieldReference fieldReference, EmptyVisitationContext context) - throws E { - Optional inputExpression = + public Optional visit( + final FieldReference fieldReference, final EmptyVisitationContext context) throws E { + final Optional inputExpression = visitOptionalExpression(fieldReference.inputExpression(), context); if (allEmpty(inputExpression)) { @@ -373,7 +380,7 @@ public Optional visit(FieldReference fieldReference, EmptyVisitation @Override public Optional visit( - Expression.SetPredicate setPredicate, EmptyVisitationContext context) throws E { + final Expression.SetPredicate setPredicate, final EmptyVisitationContext context) throws E { return setPredicate .tuples() .accept(getRelCopyOnWriteVisitor(), context) @@ -382,7 +389,8 @@ public Optional visit( @Override public Optional visit( - Expression.ScalarSubquery scalarSubquery, EmptyVisitationContext context) throws E { + final Expression.ScalarSubquery scalarSubquery, final EmptyVisitationContext context) + throws E { return scalarSubquery .input() .accept(getRelCopyOnWriteVisitor(), context) @@ -392,9 +400,10 @@ public Optional visit( @Override public Optional visit( - Expression.InPredicate inPredicate, EmptyVisitationContext context) throws E { - Optional haystack = inPredicate.haystack().accept(getRelCopyOnWriteVisitor(), context); - Optional> needles = visitExprList(inPredicate.needles(), context); + final Expression.InPredicate inPredicate, final EmptyVisitationContext context) throws E { + final Optional haystack = + inPredicate.haystack().accept(getRelCopyOnWriteVisitor(), context); + final Optional> needles = visitExprList(inPredicate.needles(), context); if (allEmpty(haystack, needles)) { return Optional.empty(); @@ -410,12 +419,12 @@ public Optional visit( // utilities protected Optional> visitExprList( - List exprs, EmptyVisitationContext context) throws E { + final List exprs, final EmptyVisitationContext context) throws E { return transformList(exprs, context, (e, c) -> e.accept(this, c)); } private Optional visitOptionalExpression( - Optional optExpr, EmptyVisitationContext context) throws E { + final Optional optExpr, final EmptyVisitationContext context) throws E { // not using optExpr.map to allow us to propagate the EXCEPTION nicely if (optExpr.isPresent()) { return optExpr.get().accept(this, context); @@ -424,7 +433,7 @@ private Optional visitOptionalExpression( } protected Optional> visitFunctionArguments( - List funcArgs, EmptyVisitationContext context) throws E { + final List funcArgs, final EmptyVisitationContext context) throws E { return CopyOnWriteUtils.transformList( funcArgs, context, @@ -438,7 +447,7 @@ protected Optional> visitFunctionArguments( } protected Optional visitSortField( - Expression.SortField sortField, EmptyVisitationContext context) throws E { + final Expression.SortField sortField, final EmptyVisitationContext context) throws E { return sortField .expr() .accept(this, context) diff --git a/core/src/main/java/io/substrait/relation/ExtensionDdl.java b/core/src/main/java/io/substrait/relation/ExtensionDdl.java index b95fc0c53..f56c1460c 100644 --- a/core/src/main/java/io/substrait/relation/ExtensionDdl.java +++ b/core/src/main/java/io/substrait/relation/ExtensionDdl.java @@ -9,7 +9,7 @@ public abstract class ExtensionDdl extends AbstractDdlRel implements HasExtensio @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/ExtensionLeaf.java b/core/src/main/java/io/substrait/relation/ExtensionLeaf.java index 7b990ae19..d5937a436 100644 --- a/core/src/main/java/io/substrait/relation/ExtensionLeaf.java +++ b/core/src/main/java/io/substrait/relation/ExtensionLeaf.java @@ -10,11 +10,11 @@ public abstract class ExtensionLeaf extends ZeroInputRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } - public static ImmutableExtensionLeaf.Builder from(Extension.LeafRelDetail detail) { + public static ImmutableExtensionLeaf.Builder from(final Extension.LeafRelDetail detail) { return ImmutableExtensionLeaf.builder() .detail(detail) .deriveRecordType(detail.deriveRecordType()); diff --git a/core/src/main/java/io/substrait/relation/ExtensionMulti.java b/core/src/main/java/io/substrait/relation/ExtensionMulti.java index 5ed3da08b..279690d9c 100644 --- a/core/src/main/java/io/substrait/relation/ExtensionMulti.java +++ b/core/src/main/java/io/substrait/relation/ExtensionMulti.java @@ -13,17 +13,17 @@ public abstract class ExtensionMulti extends AbstractRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } public static ImmutableExtensionMulti.Builder from( - Extension.MultiRelDetail detail, Rel... inputs) { + final Extension.MultiRelDetail detail, final Rel... inputs) { return from(detail, Arrays.stream(inputs).collect(Collectors.toList())); } public static ImmutableExtensionMulti.Builder from( - Extension.MultiRelDetail detail, List inputs) { + final Extension.MultiRelDetail detail, final List inputs) { return ImmutableExtensionMulti.builder() .addAllInputs(inputs) .detail(detail) diff --git a/core/src/main/java/io/substrait/relation/ExtensionSingle.java b/core/src/main/java/io/substrait/relation/ExtensionSingle.java index 69edb97d3..20a6ecb4a 100644 --- a/core/src/main/java/io/substrait/relation/ExtensionSingle.java +++ b/core/src/main/java/io/substrait/relation/ExtensionSingle.java @@ -10,11 +10,12 @@ public abstract class ExtensionSingle extends SingleInputRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } - public static ImmutableExtensionSingle.Builder from(Extension.SingleRelDetail detail, Rel input) { + public static ImmutableExtensionSingle.Builder from( + final Extension.SingleRelDetail detail, final Rel input) { return ImmutableExtensionSingle.builder() .input(input) .detail(detail) diff --git a/core/src/main/java/io/substrait/relation/ExtensionTable.java b/core/src/main/java/io/substrait/relation/ExtensionTable.java index 5cbc4231e..52e82bb85 100644 --- a/core/src/main/java/io/substrait/relation/ExtensionTable.java +++ b/core/src/main/java/io/substrait/relation/ExtensionTable.java @@ -10,11 +10,11 @@ public abstract class ExtensionTable extends AbstractReadRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } - public static ImmutableExtensionTable.Builder from(Extension.ExtensionTableDetail detail) { + public static ImmutableExtensionTable.Builder from(final Extension.ExtensionTableDetail detail) { return ImmutableExtensionTable.builder().initialSchema(detail.deriveSchema()).detail(detail); } diff --git a/core/src/main/java/io/substrait/relation/ExtensionWrite.java b/core/src/main/java/io/substrait/relation/ExtensionWrite.java index 72e39e84e..bb243571c 100644 --- a/core/src/main/java/io/substrait/relation/ExtensionWrite.java +++ b/core/src/main/java/io/substrait/relation/ExtensionWrite.java @@ -9,7 +9,7 @@ public abstract class ExtensionWrite extends AbstractWriteRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/Fetch.java b/core/src/main/java/io/substrait/relation/Fetch.java index a27726571..e2cc038da 100644 --- a/core/src/main/java/io/substrait/relation/Fetch.java +++ b/core/src/main/java/io/substrait/relation/Fetch.java @@ -19,7 +19,7 @@ protected Type.Struct deriveRecordType() { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/Filter.java b/core/src/main/java/io/substrait/relation/Filter.java index e7fdfd3bb..9ecd75957 100644 --- a/core/src/main/java/io/substrait/relation/Filter.java +++ b/core/src/main/java/io/substrait/relation/Filter.java @@ -17,7 +17,7 @@ protected Type.Struct deriveRecordType() { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/Join.java b/core/src/main/java/io/substrait/relation/Join.java index adb9cd535..0d8743a56 100644 --- a/core/src/main/java/io/substrait/relation/Join.java +++ b/core/src/main/java/io/substrait/relation/Join.java @@ -42,7 +42,7 @@ public enum JoinType { private JoinRel.JoinType proto; - JoinType(JoinRel.JoinType proto) { + JoinType(final JoinRel.JoinType proto) { this.proto = proto; } @@ -50,8 +50,8 @@ public JoinRel.JoinType toProto() { return proto; } - public static JoinType fromProto(JoinRel.JoinType proto) { - for (JoinType v : values()) { + public static JoinType fromProto(final JoinRel.JoinType proto) { + for (final JoinType v : values()) { if (v.proto == proto) { return v; } @@ -63,8 +63,8 @@ public static JoinType fromProto(JoinRel.JoinType proto) { @Override protected Type.Struct deriveRecordType() { - Stream leftTypes = getLeftTypes(); - Stream rightTypes = getRightTypes(); + final Stream leftTypes = getLeftTypes(); + final Stream rightTypes = getRightTypes(); return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } @@ -108,7 +108,7 @@ private Stream getRightTypes() { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/LocalFiles.java b/core/src/main/java/io/substrait/relation/LocalFiles.java index cd3e5b9c5..871a510c5 100644 --- a/core/src/main/java/io/substrait/relation/LocalFiles.java +++ b/core/src/main/java/io/substrait/relation/LocalFiles.java @@ -12,7 +12,7 @@ public abstract class LocalFiles extends AbstractReadRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/NamedDdl.java b/core/src/main/java/io/substrait/relation/NamedDdl.java index 873e4b481..73a99d114 100644 --- a/core/src/main/java/io/substrait/relation/NamedDdl.java +++ b/core/src/main/java/io/substrait/relation/NamedDdl.java @@ -10,7 +10,7 @@ public abstract class NamedDdl extends AbstractDdlRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/NamedScan.java b/core/src/main/java/io/substrait/relation/NamedScan.java index 225a5b27b..ec766d92f 100644 --- a/core/src/main/java/io/substrait/relation/NamedScan.java +++ b/core/src/main/java/io/substrait/relation/NamedScan.java @@ -11,7 +11,7 @@ public abstract class NamedScan extends AbstractReadRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/NamedUpdate.java b/core/src/main/java/io/substrait/relation/NamedUpdate.java index f17947c85..7d062f1f8 100644 --- a/core/src/main/java/io/substrait/relation/NamedUpdate.java +++ b/core/src/main/java/io/substrait/relation/NamedUpdate.java @@ -11,7 +11,7 @@ public abstract class NamedUpdate extends AbstractUpdate { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/NamedWrite.java b/core/src/main/java/io/substrait/relation/NamedWrite.java index bfe087d0f..bd58e88a0 100644 --- a/core/src/main/java/io/substrait/relation/NamedWrite.java +++ b/core/src/main/java/io/substrait/relation/NamedWrite.java @@ -10,7 +10,7 @@ public abstract class NamedWrite extends AbstractWriteRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/Project.java b/core/src/main/java/io/substrait/relation/Project.java index 8100e5d7b..90e08e360 100644 --- a/core/src/main/java/io/substrait/relation/Project.java +++ b/core/src/main/java/io/substrait/relation/Project.java @@ -15,7 +15,7 @@ public abstract class Project extends SingleInputRel implements HasExtension { @Override public Type.Struct deriveRecordType() { - Type.Struct initial = getInput().getRecordType(); + final Type.Struct initial = getInput().getRecordType(); return TypeCreator.of(initial.nullable()) .struct( Stream.concat( @@ -24,7 +24,7 @@ public Type.Struct deriveRecordType() { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java b/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java index c17245fb0..d3d6d554a 100644 --- a/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java @@ -24,14 +24,14 @@ public class ProtoAggregateFunctionConverter { private final ProtoExpressionConverter protoExpressionConverter; public ProtoAggregateFunctionConverter( - ExtensionLookup lookup, ProtoExpressionConverter protoExpressionConverter) { + final ExtensionLookup lookup, final ProtoExpressionConverter protoExpressionConverter) { this(lookup, DefaultExtensionCatalog.DEFAULT_COLLECTION, protoExpressionConverter); } public ProtoAggregateFunctionConverter( - ExtensionLookup lookup, - SimpleExtension.ExtensionCollection extensions, - ProtoExpressionConverter protoExpressionConverter) { + final ExtensionLookup lookup, + final SimpleExtension.ExtensionCollection extensions, + final ProtoExpressionConverter protoExpressionConverter) { this.lookup = lookup; this.extensions = extensions; this.protoTypeConverter = new ProtoTypeConverter(lookup, extensions); @@ -39,20 +39,20 @@ public ProtoAggregateFunctionConverter( } public io.substrait.expression.AggregateFunctionInvocation from( - io.substrait.proto.AggregateFunction measure) { - FunctionArg.ProtoFrom protoFrom = + final io.substrait.proto.AggregateFunction measure) { + final FunctionArg.ProtoFrom protoFrom = new FunctionArg.ProtoFrom(protoExpressionConverter, protoTypeConverter); - SimpleExtension.AggregateFunctionVariant aggregateFunction = + final SimpleExtension.AggregateFunctionVariant aggregateFunction = lookup.getAggregateFunction(measure.getFunctionReference(), extensions); - List functionArgs = + final List functionArgs = IntStream.range(0, measure.getArgumentsCount()) .mapToObj(i -> protoFrom.convert(aggregateFunction, i, measure.getArguments(i))) .collect(java.util.stream.Collectors.toList()); - List options = + final List options = measure.getOptionsList().stream() .map(ProtoExpressionConverter::fromFunctionOption) .collect(Collectors.toList()); - List sorts = + final List sorts = measure.getSortsList().stream() .map(protoExpressionConverter::fromSortField) .collect(Collectors.toList()); diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index fca347f81..b90fd92d7 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -131,12 +131,12 @@ public ProtoRelConverter( this.protoExtensionConverter = protoExtensionConverter; } - public Plan.Root from(io.substrait.proto.RelRoot rel) { + public Plan.Root from(final io.substrait.proto.RelRoot rel) { return Plan.Root.builder().input(from(rel.getInput())).addAllNames(rel.getNamesList()).build(); } - public Rel from(io.substrait.proto.Rel rel) { - io.substrait.proto.Rel.RelTypeCase relType = rel.getRelTypeCase(); + public Rel from(final io.substrait.proto.Rel rel) { + final io.substrait.proto.Rel.RelTypeCase relType = rel.getRelTypeCase(); switch (relType) { case READ: return newRead(rel.getRead()); @@ -185,9 +185,9 @@ public Rel from(io.substrait.proto.Rel rel) { } } - protected Rel newRead(ReadRel rel) { + protected Rel newRead(final ReadRel rel) { if (rel.hasVirtualTable()) { - ReadRel.VirtualTable virtualTable = rel.getVirtualTable(); + final ReadRel.VirtualTable virtualTable = rel.getVirtualTable(); if (virtualTable.getValuesCount() == 0) { return newEmptyScan(rel); } else { @@ -318,13 +318,13 @@ protected ExtensionDdl newExtensionDdl(final DdlRel rel) { return builder.build(); } - protected Optional optionalViewDefinition(DdlRel rel) { + protected Optional optionalViewDefinition(final DdlRel rel) { return Optional.ofNullable(rel.hasViewDefinition() ? from(rel.getViewDefinition()) : null); } protected Expression.StructLiteral tableDefaults( - io.substrait.proto.Expression.Literal.Struct struct, NamedStruct tableSchema) { - ProtoExpressionConverter converter = + final io.substrait.proto.Expression.Literal.Struct struct, final NamedStruct tableSchema) { + final ProtoExpressionConverter converter = new ProtoExpressionConverter(lookup, extensions, tableSchema.struct(), this); return Expression.StructLiteral.builder() .fields( @@ -334,8 +334,8 @@ protected Expression.StructLiteral tableDefaults( .build(); } - protected Rel newUpdate(UpdateRel rel) { - UpdateRel.UpdateTypeCase relType = rel.getUpdateTypeCase(); + protected Rel newUpdate(final UpdateRel rel) { + final UpdateRel.UpdateTypeCase relType = rel.getUpdateTypeCase(); switch (relType) { case NAMED_TABLE: return newNamedUpdate(rel); @@ -344,20 +344,20 @@ protected Rel newUpdate(UpdateRel rel) { } } - protected Rel newNamedUpdate(UpdateRel rel) { - NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); - ProtoExpressionConverter converter = + protected Rel newNamedUpdate(final UpdateRel rel) { + final NamedStruct tableSchema = newNamedStruct(rel.getTableSchema()); + final ProtoExpressionConverter converter = new ProtoExpressionConverter(lookup, extensions, tableSchema.struct(), this); - List transformations = + final List transformations = new ArrayList<>(rel.getTransformationsCount()); - for (UpdateRel.TransformExpression transformation : rel.getTransformationsList()) { + for (final UpdateRel.TransformExpression transformation : rel.getTransformationsList()) { transformations.add( NamedUpdate.TransformExpression.builder() .transformation(converter.from(transformation.getTransformation())) .columnTarget(transformation.getColumnTarget()) .build()); } - ImmutableNamedUpdate.Builder builder = + final ImmutableNamedUpdate.Builder builder = NamedUpdate.builder() .names(rel.getNamedTable().getNamesList()) .tableSchema(tableSchema) @@ -369,9 +369,9 @@ protected Rel newNamedUpdate(UpdateRel rel) { return builder.build(); } - protected Filter newFilter(FilterRel rel) { - Rel input = from(rel.getInput()); - ImmutableFilter.Builder builder = + protected Filter newFilter(final FilterRel rel) { + final Rel input = from(rel.getInput()); + final ImmutableFilter.Builder builder = Filter.builder() .input(input) .condition( @@ -387,12 +387,12 @@ protected Filter newFilter(FilterRel rel) { return builder.build(); } - protected NamedStruct newNamedStruct(ReadRel rel) { + protected NamedStruct newNamedStruct(final ReadRel rel) { return newNamedStruct(rel.getBaseSchema()); } - protected NamedStruct newNamedStruct(io.substrait.proto.NamedStruct namedStruct) { - io.substrait.proto.Type.Struct struct = namedStruct.getStruct(); + protected NamedStruct newNamedStruct(final io.substrait.proto.NamedStruct namedStruct) { + final io.substrait.proto.Type.Struct struct = namedStruct.getStruct(); return NamedStruct.builder() .names(namedStruct.getNamesList()) .struct( @@ -406,9 +406,9 @@ protected NamedStruct newNamedStruct(io.substrait.proto.NamedStruct namedStruct) .build(); } - protected EmptyScan newEmptyScan(ReadRel rel) { - NamedStruct namedStruct = newNamedStruct(rel); - ImmutableEmptyScan.Builder builder = + protected EmptyScan newEmptyScan(final ReadRel rel) { + final NamedStruct namedStruct = newNamedStruct(rel); + final ImmutableEmptyScan.Builder builder = EmptyScan.builder() .initialSchema(namedStruct) .bestEffortFilter( @@ -436,9 +436,9 @@ protected EmptyScan newEmptyScan(ReadRel rel) { return builder.build(); } - protected ExtensionLeaf newExtensionLeaf(ExtensionLeafRel rel) { - Extension.LeafRelDetail detail = detailFromExtensionLeafRel(rel.getDetail()); - ImmutableExtensionLeaf.Builder builder = + protected ExtensionLeaf newExtensionLeaf(final ExtensionLeafRel rel) { + final Extension.LeafRelDetail detail = detailFromExtensionLeafRel(rel.getDetail()); + final ImmutableExtensionLeaf.Builder builder = ExtensionLeaf.from(detail) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) @@ -446,10 +446,10 @@ protected ExtensionLeaf newExtensionLeaf(ExtensionLeafRel rel) { return builder.build(); } - protected ExtensionSingle newExtensionSingle(ExtensionSingleRel rel) { - Extension.SingleRelDetail detail = detailFromExtensionSingleRel(rel.getDetail()); - Rel input = from(rel.getInput()); - ImmutableExtensionSingle.Builder builder = + protected ExtensionSingle newExtensionSingle(final ExtensionSingleRel rel) { + final Extension.SingleRelDetail detail = detailFromExtensionSingleRel(rel.getDetail()); + final Rel input = from(rel.getInput()); + final ImmutableExtensionSingle.Builder builder = ExtensionSingle.from(detail, input) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) @@ -457,10 +457,11 @@ protected ExtensionSingle newExtensionSingle(ExtensionSingleRel rel) { return builder.build(); } - protected ExtensionMulti newExtensionMulti(ExtensionMultiRel rel) { - Extension.MultiRelDetail detail = detailFromExtensionMultiRel(rel.getDetail()); - List inputs = rel.getInputsList().stream().map(this::from).collect(Collectors.toList()); - ImmutableExtensionMulti.Builder builder = + protected ExtensionMulti newExtensionMulti(final ExtensionMultiRel rel) { + final Extension.MultiRelDetail detail = detailFromExtensionMultiRel(rel.getDetail()); + final List inputs = + rel.getInputsList().stream().map(this::from).collect(Collectors.toList()); + final ImmutableExtensionMulti.Builder builder = ExtensionMulti.from(detail, inputs) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) @@ -471,9 +472,9 @@ protected ExtensionMulti newExtensionMulti(ExtensionMultiRel rel) { return builder.build(); } - protected NamedScan newNamedScan(ReadRel rel) { - NamedStruct namedStruct = newNamedStruct(rel); - ImmutableNamedScan.Builder builder = + protected NamedScan newNamedScan(final ReadRel rel) { + final NamedStruct namedStruct = newNamedStruct(rel); + final ImmutableNamedScan.Builder builder = NamedScan.builder() .initialSchema(namedStruct) .names(rel.getNamedTable().getNamesList()) @@ -519,10 +520,10 @@ protected ExtensionTable newExtensionTable(final ReadRel rel) { return builder.build(); } - protected LocalFiles newLocalFiles(ReadRel rel) { - NamedStruct namedStruct = newNamedStruct(rel); + protected LocalFiles newLocalFiles(final ReadRel rel) { + final NamedStruct namedStruct = newNamedStruct(rel); - ImmutableLocalFiles.Builder builder = + final ImmutableLocalFiles.Builder builder = LocalFiles.builder() .initialSchema(namedStruct) .addAllItems( @@ -554,8 +555,8 @@ protected LocalFiles newLocalFiles(ReadRel rel) { return builder.build(); } - protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) { - io.substrait.relation.files.ImmutableFileOrFiles.Builder builder = + protected FileOrFiles newFileOrFiles(final ReadRel.LocalFiles.FileOrFiles file) { + final io.substrait.relation.files.ImmutableFileOrFiles.Builder builder = FileOrFiles.builder() .partitionIndex(file.getPartitionIndex()) .start(file.getStart()) @@ -569,7 +570,8 @@ protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) { } else if (file.hasDwrf()) { builder.fileFormat(FileFormat.DwrfReadOptions.builder().build()); } else if (file.hasText()) { - io.substrait.relation.files.ImmutableFileFormat.DelimiterSeparatedTextReadOptions.Builder + final io.substrait.relation.files.ImmutableFileFormat.DelimiterSeparatedTextReadOptions + .Builder ffBuilder = FileFormat.DelimiterSeparatedTextReadOptions.builder() .fieldDelimiter(file.getText().getFieldDelimiter()) @@ -596,13 +598,14 @@ protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) { return builder.build(); } - protected VirtualTableScan newVirtualTable(ReadRel rel) { - ReadRel.VirtualTable virtualTable = rel.getVirtualTable(); - NamedStruct virtualTableSchema = newNamedStruct(rel); - ProtoExpressionConverter converter = + protected VirtualTableScan newVirtualTable(final ReadRel rel) { + final ReadRel.VirtualTable virtualTable = rel.getVirtualTable(); + final NamedStruct virtualTableSchema = newNamedStruct(rel); + final ProtoExpressionConverter converter = new ProtoExpressionConverter(lookup, extensions, virtualTableSchema.struct(), this); - List structLiterals = new ArrayList<>(virtualTable.getValuesCount()); - for (io.substrait.proto.Expression.Literal.Struct struct : virtualTable.getValuesList()) { + final List structLiterals = + new ArrayList<>(virtualTable.getValuesCount()); + for (final io.substrait.proto.Expression.Literal.Struct struct : virtualTable.getValuesList()) { structLiterals.add( Expression.StructLiteral.builder() .fields( @@ -612,7 +615,7 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) { .build()); } - ImmutableVirtualTableScan.Builder builder = + final ImmutableVirtualTableScan.Builder builder = VirtualTableScan.builder() .bestEffortFilter( Optional.ofNullable( @@ -631,9 +634,9 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) { return builder.build(); } - protected Fetch newFetch(FetchRel rel) { - Rel input = from(rel.getInput()); - ImmutableFetch.Builder builder = Fetch.builder().input(input).offset(rel.getOffset()); + protected Fetch newFetch(final FetchRel rel) { + final Rel input = from(rel.getInput()); + final ImmutableFetch.Builder builder = Fetch.builder().input(input).offset(rel.getOffset()); if (rel.getCount() != -1) { // -1 is used as a sentinel value to signal LIMIT ALL // count only needs to be set when it is not -1 @@ -650,11 +653,11 @@ protected Fetch newFetch(FetchRel rel) { return builder.build(); } - protected Project newProject(ProjectRel rel) { - Rel input = from(rel.getInput()); - ProtoExpressionConverter converter = + protected Project newProject(final ProjectRel rel) { + final Rel input = from(rel.getInput()); + final ProtoExpressionConverter converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - ImmutableProject.Builder builder = + final ImmutableProject.Builder builder = Project.builder() .input(input) .expressions( @@ -672,11 +675,11 @@ protected Project newProject(ProjectRel rel) { return builder.build(); } - protected Expand newExpand(ExpandRel rel) { - Rel input = from(rel.getInput()); - ProtoExpressionConverter converter = + protected Expand newExpand(final ExpandRel rel) { + final Rel input = from(rel.getInput()); + final ProtoExpressionConverter converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - ImmutableExpand.Builder builder = + final ImmutableExpand.Builder builder = Expand.builder() .input(input) .fields( @@ -709,26 +712,26 @@ protected Expand newExpand(ExpandRel rel) { return builder.build(); } - protected Aggregate newAggregate(AggregateRel rel) { - Rel input = from(rel.getInput()); - ProtoExpressionConverter protoExprConverter = + protected Aggregate newAggregate(final AggregateRel rel) { + final Rel input = from(rel.getInput()); + final ProtoExpressionConverter protoExprConverter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - ProtoAggregateFunctionConverter protoAggrFuncConverter = + final ProtoAggregateFunctionConverter protoAggrFuncConverter = new ProtoAggregateFunctionConverter(lookup, extensions, protoExprConverter); - List groupings = new ArrayList<>(rel.getGroupingsCount()); + final List groupings = new ArrayList<>(rel.getGroupingsCount()); // Groupings are set using the AggregateRel grouping_expression mechanism if (!rel.getGroupingExpressionsList().isEmpty()) { - List allGroupingExpressions = + final List allGroupingExpressions = rel.getGroupingExpressionsList().stream() .map(protoExprConverter::from) .collect(java.util.stream.Collectors.toList()); - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { - List references = grouping.getExpressionReferencesList(); - List groupExpressions = new ArrayList<>(); - for (int ref : references) { + for (final AggregateRel.Grouping grouping : rel.getGroupingsList()) { + final List references = grouping.getExpressionReferencesList(); + final List groupExpressions = new ArrayList<>(); + for (final int ref : references) { groupExpressions.add(allGroupingExpressions.get(ref)); } groupings.add(Aggregate.Grouping.builder().addAllExpressions(groupExpressions).build()); @@ -736,7 +739,7 @@ protected Aggregate newAggregate(AggregateRel rel) { } else { // Groupings are set using the deprecated Grouping grouping_expressions mechanism - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { + for (final AggregateRel.Grouping grouping : rel.getGroupingsList()) { groupings.add( Aggregate.Grouping.builder() .expressions( @@ -747,8 +750,8 @@ protected Aggregate newAggregate(AggregateRel rel) { } } - List measures = new ArrayList<>(rel.getMeasuresCount()); - for (AggregateRel.Measure measure : rel.getMeasuresList()) { + final List measures = new ArrayList<>(rel.getMeasuresCount()); + for (final AggregateRel.Measure measure : rel.getMeasuresList()) { measures.add( Aggregate.Measure.builder() .function(protoAggrFuncConverter.from(measure.getMeasure())) @@ -757,7 +760,7 @@ protected Aggregate newAggregate(AggregateRel rel) { measure.hasFilter() ? protoExprConverter.from(measure.getFilter()) : null)) .build()); } - ImmutableAggregate.Builder builder = + final ImmutableAggregate.Builder builder = Aggregate.builder().input(input).groupings(groupings).measures(measures); builder @@ -770,11 +773,11 @@ protected Aggregate newAggregate(AggregateRel rel) { return builder.build(); } - protected Sort newSort(SortRel rel) { - Rel input = from(rel.getInput()); - ProtoExpressionConverter converter = + protected Sort newSort(final SortRel rel) { + final Rel input = from(rel.getInput()); + final ProtoExpressionConverter converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - ImmutableSort.Builder builder = + final ImmutableSort.Builder builder = Sort.builder() .input(input) .sortFields( @@ -797,15 +800,16 @@ protected Sort newSort(SortRel rel) { return builder.build(); } - protected Join newJoin(JoinRel rel) { - Rel left = from(rel.getLeft()); - Rel right = from(rel.getRight()); - Type.Struct leftStruct = left.getRecordType(); - Type.Struct rightStruct = right.getRecordType(); - Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - ProtoExpressionConverter converter = + protected Join newJoin(final JoinRel rel) { + final Rel left = from(rel.getLeft()); + final Rel right = from(rel.getRight()); + final Type.Struct leftStruct = left.getRecordType(); + final Type.Struct rightStruct = right.getRecordType(); + final Type.Struct unionedStruct = + Type.Struct.builder().from(leftStruct).from(rightStruct).build(); + final ProtoExpressionConverter converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); - ImmutableJoin.Builder builder = + final ImmutableJoin.Builder builder = Join.builder() .left(left) .right(right) @@ -825,10 +829,10 @@ protected Join newJoin(JoinRel rel) { return builder.build(); } - protected Rel newCross(CrossRel rel) { - Rel left = from(rel.getLeft()); - Rel right = from(rel.getRight()); - ImmutableCross.Builder builder = Cross.builder().left(left).right(right); + protected Rel newCross(final CrossRel rel) { + final Rel left = from(rel.getLeft()); + final Rel right = from(rel.getRight()); + final ImmutableCross.Builder builder = Cross.builder().left(left).right(right); builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) @@ -839,12 +843,12 @@ protected Rel newCross(CrossRel rel) { return builder.build(); } - protected Set newSet(SetRel rel) { - List inputs = + protected Set newSet(final SetRel rel) { + final List inputs = rel.getInputsList().stream() .map(inputRel -> from(inputRel)) .collect(java.util.stream.Collectors.toList()); - ImmutableSet.Builder builder = + final ImmutableSet.Builder builder = Set.builder().inputs(inputs).setOp(Set.SetOp.fromProto(rel.getOp())); builder @@ -857,22 +861,23 @@ protected Set newSet(SetRel rel) { return builder.build(); } - protected Rel newHashJoin(HashJoinRel rel) { - Rel left = from(rel.getLeft()); - Rel right = from(rel.getRight()); - List leftKeys = rel.getLeftKeysList(); - List rightKeys = rel.getRightKeysList(); - - Type.Struct leftStruct = left.getRecordType(); - Type.Struct rightStruct = right.getRecordType(); - Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - ProtoExpressionConverter leftConverter = + protected Rel newHashJoin(final HashJoinRel rel) { + final Rel left = from(rel.getLeft()); + final Rel right = from(rel.getRight()); + final List leftKeys = rel.getLeftKeysList(); + final List rightKeys = rel.getRightKeysList(); + + final Type.Struct leftStruct = left.getRecordType(); + final Type.Struct rightStruct = right.getRecordType(); + final Type.Struct unionedStruct = + Type.Struct.builder().from(leftStruct).from(rightStruct).build(); + final ProtoExpressionConverter leftConverter = new ProtoExpressionConverter(lookup, extensions, leftStruct, this); - ProtoExpressionConverter rightConverter = + final ProtoExpressionConverter rightConverter = new ProtoExpressionConverter(lookup, extensions, rightStruct, this); - ProtoExpressionConverter unionConverter = + final ProtoExpressionConverter unionConverter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); - io.substrait.relation.physical.ImmutableHashJoin.Builder builder = + final io.substrait.relation.physical.ImmutableHashJoin.Builder builder = HashJoin.builder() .left(left) .right(right) @@ -892,22 +897,23 @@ protected Rel newHashJoin(HashJoinRel rel) { return builder.build(); } - protected Rel newMergeJoin(MergeJoinRel rel) { - Rel left = from(rel.getLeft()); - Rel right = from(rel.getRight()); - List leftKeys = rel.getLeftKeysList(); - List rightKeys = rel.getRightKeysList(); - - Type.Struct leftStruct = left.getRecordType(); - Type.Struct rightStruct = right.getRecordType(); - Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - ProtoExpressionConverter leftConverter = + protected Rel newMergeJoin(final MergeJoinRel rel) { + final Rel left = from(rel.getLeft()); + final Rel right = from(rel.getRight()); + final List leftKeys = rel.getLeftKeysList(); + final List rightKeys = rel.getRightKeysList(); + + final Type.Struct leftStruct = left.getRecordType(); + final Type.Struct rightStruct = right.getRecordType(); + final Type.Struct unionedStruct = + Type.Struct.builder().from(leftStruct).from(rightStruct).build(); + final ProtoExpressionConverter leftConverter = new ProtoExpressionConverter(lookup, extensions, leftStruct, this); - ProtoExpressionConverter rightConverter = + final ProtoExpressionConverter rightConverter = new ProtoExpressionConverter(lookup, extensions, rightStruct, this); - ProtoExpressionConverter unionConverter = + final ProtoExpressionConverter unionConverter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); - io.substrait.relation.physical.ImmutableMergeJoin.Builder builder = + final io.substrait.relation.physical.ImmutableMergeJoin.Builder builder = MergeJoin.builder() .left(left) .right(right) @@ -928,15 +934,16 @@ protected Rel newMergeJoin(MergeJoinRel rel) { return builder.build(); } - protected NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { - Rel left = from(rel.getLeft()); - Rel right = from(rel.getRight()); - Type.Struct leftStruct = left.getRecordType(); - Type.Struct rightStruct = right.getRecordType(); - Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - ProtoExpressionConverter converter = + protected NestedLoopJoin newNestedLoopJoin(final NestedLoopJoinRel rel) { + final Rel left = from(rel.getLeft()); + final Rel right = from(rel.getRight()); + final Type.Struct leftStruct = left.getRecordType(); + final Type.Struct rightStruct = right.getRecordType(); + final Type.Struct unionedStruct = + Type.Struct.builder().from(leftStruct).from(rightStruct).build(); + final ProtoExpressionConverter converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); - io.substrait.relation.physical.ImmutableNestedLoopJoin.Builder builder = + final io.substrait.relation.physical.ImmutableNestedLoopJoin.Builder builder = NestedLoopJoin.builder() .left(left) .right(right) @@ -958,26 +965,26 @@ protected NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { } protected ConsistentPartitionWindow newConsistentPartitionWindow( - ConsistentPartitionWindowRel rel) { + final ConsistentPartitionWindowRel rel) { - Rel input = from(rel.getInput()); - ProtoExpressionConverter protoExpressionConverter = + final Rel input = from(rel.getInput()); + final ProtoExpressionConverter protoExpressionConverter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - List partitionExprs = + final List partitionExprs = rel.getPartitionExpressionsList().stream() .map(protoExpressionConverter::from) .collect(Collectors.toList()); - List sortFields = + final List sortFields = rel.getSortsList().stream() .map(protoExpressionConverter::fromSortField) .collect(Collectors.toList()); - List windowRelFunctions = + final List windowRelFunctions = rel.getWindowFunctionsList().stream() .map(protoExpressionConverter::fromWindowRelFunction) .collect(Collectors.toList()); - ImmutableConsistentPartitionWindow.Builder builder = + final ImmutableConsistentPartitionWindow.Builder builder = ConsistentPartitionWindow.builder() .input(input) .partitionExpressions(partitionExprs) @@ -994,8 +1001,8 @@ protected ConsistentPartitionWindow newConsistentPartitionWindow( return builder.build(); } - protected AbstractExchangeRel newExchange(ExchangeRel rel) { - ExchangeRel.ExchangeKindCase exchangeKind = rel.getExchangeKindCase(); + protected AbstractExchangeRel newExchange(final ExchangeRel rel) { + final ExchangeRel.ExchangeKindCase exchangeKind = rel.getExchangeKindCase(); switch (exchangeKind) { case SCATTER_BY_FIELDS: return newScatterExchange(rel); @@ -1012,19 +1019,19 @@ protected AbstractExchangeRel newExchange(ExchangeRel rel) { } } - protected ScatterExchange newScatterExchange(ExchangeRel rel) { - Rel input = from(rel.getInput()); - List targets = + protected ScatterExchange newScatterExchange(final ExchangeRel rel) { + final Rel input = from(rel.getInput()); + final List targets = rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); - ProtoExpressionConverter protoExprConverter = + final ProtoExpressionConverter protoExprConverter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - List fieldReferences = + final List fieldReferences = rel.getScatterByFields().getFieldsList().stream() .map(protoExprConverter::from) .collect(Collectors.toList()); - ImmutableScatterExchange.Builder builder = + final ImmutableScatterExchange.Builder builder = ScatterExchange.builder() .input(input) .addAllFields(fieldReferences) @@ -1041,14 +1048,14 @@ protected ScatterExchange newScatterExchange(ExchangeRel rel) { return builder.build(); } - protected SingleBucketExchange newSingleBucketExchange(ExchangeRel rel) { - Rel input = from(rel.getInput()); - List targets = + protected SingleBucketExchange newSingleBucketExchange(final ExchangeRel rel) { + final Rel input = from(rel.getInput()); + final List targets = rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); - ProtoExpressionConverter protoExprConverter = + final ProtoExpressionConverter protoExprConverter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - ImmutableSingleBucketExchange.Builder builder = + final ImmutableSingleBucketExchange.Builder builder = SingleBucketExchange.builder() .input(input) .partitionCount(rel.getPartitionCount()) @@ -1065,14 +1072,14 @@ protected SingleBucketExchange newSingleBucketExchange(ExchangeRel rel) { return builder.build(); } - protected MultiBucketExchange newMultiBucketExchange(ExchangeRel rel) { - Rel input = from(rel.getInput()); - List targets = + protected MultiBucketExchange newMultiBucketExchange(final ExchangeRel rel) { + final Rel input = from(rel.getInput()); + final List targets = rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); - ProtoExpressionConverter protoExprConverter = + final ProtoExpressionConverter protoExprConverter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); - ImmutableMultiBucketExchange.Builder builder = + final ImmutableMultiBucketExchange.Builder builder = MultiBucketExchange.builder() .input(input) .partitionCount(rel.getPartitionCount()) @@ -1090,12 +1097,12 @@ protected MultiBucketExchange newMultiBucketExchange(ExchangeRel rel) { return builder.build(); } - protected RoundRobinExchange newRoundRobinExchange(ExchangeRel rel) { - Rel input = from(rel.getInput()); - List targets = + protected RoundRobinExchange newRoundRobinExchange(final ExchangeRel rel) { + final Rel input = from(rel.getInput()); + final List targets = rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); - ImmutableRoundRobinExchange.Builder builder = + final ImmutableRoundRobinExchange.Builder builder = RoundRobinExchange.builder() .input(input) .partitionCount(rel.getPartitionCount()) @@ -1112,12 +1119,12 @@ protected RoundRobinExchange newRoundRobinExchange(ExchangeRel rel) { return builder.build(); } - protected BroadcastExchange newBroadcastExchange(ExchangeRel rel) { - Rel input = from(rel.getInput()); - List targets = + protected BroadcastExchange newBroadcastExchange(final ExchangeRel rel) { + final Rel input = from(rel.getInput()); + final List targets = rel.getTargetsList().stream().map(this::newExchangeTarget).collect(Collectors.toList()); - ImmutableBroadcastExchange.Builder builder = + final ImmutableBroadcastExchange.Builder builder = BroadcastExchange.builder() .input(input) .partitionCount(rel.getPartitionCount()) @@ -1134,8 +1141,8 @@ protected BroadcastExchange newBroadcastExchange(ExchangeRel rel) { } protected AbstractExchangeRel.ExchangeTarget newExchangeTarget( - ExchangeRel.ExchangeTarget target) { - ImmutableExchangeTarget.Builder builder = AbstractExchangeRel.ExchangeTarget.builder(); + final ExchangeRel.ExchangeTarget target) { + final ImmutableExchangeTarget.Builder builder = AbstractExchangeRel.ExchangeTarget.builder(); builder.addAllPartitionIds(target.getPartitionIdList()); switch (target.getTargetTypeCase()) { case URI: @@ -1151,15 +1158,16 @@ protected AbstractExchangeRel.ExchangeTarget newExchangeTarget( return builder.build(); } - protected static Optional optionalRelmap(io.substrait.proto.RelCommon relCommon) { + protected static Optional optionalRelmap( + final io.substrait.proto.RelCommon relCommon) { return Optional.ofNullable( relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null); } - protected Optional optionalHint(io.substrait.proto.RelCommon relCommon) { + protected Optional optionalHint(final io.substrait.proto.RelCommon relCommon) { if (!relCommon.hasHint()) return Optional.empty(); - io.substrait.proto.RelCommon.Hint hint = relCommon.getHint(); - io.substrait.hint.ImmutableHint.Builder builder = + final io.substrait.proto.RelCommon.Hint hint = relCommon.getHint(); + final io.substrait.hint.ImmutableHint.Builder builder = Hint.builder().addAllOutputNames(hint.getOutputNamesList()); if (!hint.getAlias().isEmpty()) { builder.alias(hint.getAlias()); @@ -1168,8 +1176,8 @@ protected Optional optionalHint(io.substrait.proto.RelCommon relCommon) { builder.extension(protoExtensionConverter.fromProto(hint.getAdvancedExtension())); } if (hint.hasStats()) { - io.substrait.proto.RelCommon.Hint.Stats stats = hint.getStats(); - io.substrait.hint.ImmutableStats.Builder statsBuilder = Stats.builder(); + final io.substrait.proto.RelCommon.Hint.Stats stats = hint.getStats(); + final io.substrait.hint.ImmutableStats.Builder statsBuilder = Stats.builder(); statsBuilder.recordSize(stats.getRecordSize()).rowCount(stats.getRowCount()); if (stats.hasAdvancedExtension()) { statsBuilder.extension(protoExtensionConverter.fromProto(stats.getAdvancedExtension())); @@ -1177,8 +1185,8 @@ protected Optional optionalHint(io.substrait.proto.RelCommon relCommon) { builder.stats(statsBuilder.build()); } if (hint.hasConstraint()) { - io.substrait.proto.RelCommon.Hint.RuntimeConstraint constraint = hint.getConstraint(); - io.substrait.hint.ImmutableRuntimeConstraint.Builder constraintBuilder = + final io.substrait.proto.RelCommon.Hint.RuntimeConstraint constraint = hint.getConstraint(); + final io.substrait.hint.ImmutableRuntimeConstraint.Builder constraintBuilder = RuntimeConstraint.builder(); if (constraint.hasAdvancedExtension()) { constraintBuilder.extension( @@ -1208,7 +1216,7 @@ protected Optional optionalHint(io.substrait.proto.RelCommon relCommon) { } protected Optional optionalAdvancedExtension( - io.substrait.proto.RelCommon relCommon) { + final io.substrait.proto.RelCommon relCommon) { return Optional.ofNullable( relCommon.hasAdvancedExtension() ? protoExtensionConverter.fromProto(relCommon.getAdvancedExtension()) @@ -1216,17 +1224,19 @@ protected Optional optionalAdvancedExtension( } /** Override to provide a custom converter for {@link ExtensionLeafRel#getDetail()} data */ - protected Extension.LeafRelDetail detailFromExtensionLeafRel(com.google.protobuf.Any any) { + protected Extension.LeafRelDetail detailFromExtensionLeafRel(final com.google.protobuf.Any any) { return emptyDetail(); } /** Override to provide a custom converter for {@link ExtensionSingleRel#getDetail()} data */ - protected Extension.SingleRelDetail detailFromExtensionSingleRel(com.google.protobuf.Any any) { + protected Extension.SingleRelDetail detailFromExtensionSingleRel( + final com.google.protobuf.Any any) { return emptyDetail(); } /** Override to provide a custom converter for {@link ExtensionMultiRel#getDetail()} data */ - protected Extension.MultiRelDetail detailFromExtensionMultiRel(com.google.protobuf.Any any) { + protected Extension.MultiRelDetail detailFromExtensionMultiRel( + final com.google.protobuf.Any any) { return emptyDetail(); } @@ -1234,16 +1244,18 @@ protected Extension.MultiRelDetail detailFromExtensionMultiRel(com.google.protob * Override to provide a custom converter for {@link * io.substrait.proto.ReadRel.ExtensionTable#getDetail()} data */ - protected Extension.ExtensionTableDetail detailFromExtensionTable(com.google.protobuf.Any any) { + protected Extension.ExtensionTableDetail detailFromExtensionTable( + final com.google.protobuf.Any any) { return emptyDetail(); } protected Extension.WriteExtensionObject detailFromWriteExtensionObject( - com.google.protobuf.Any any) { + final com.google.protobuf.Any any) { return emptyDetail(); } - protected Extension.DdlExtensionObject detailFromDdlExtensionObject(com.google.protobuf.Any any) { + protected Extension.DdlExtensionObject detailFromDdlExtensionObject( + final com.google.protobuf.Any any) { return emptyDetail(); } diff --git a/core/src/main/java/io/substrait/relation/Rel.java b/core/src/main/java/io/substrait/relation/Rel.java index ab2d67093..4f52f053b 100644 --- a/core/src/main/java/io/substrait/relation/Rel.java +++ b/core/src/main/java/io/substrait/relation/Rel.java @@ -29,16 +29,16 @@ public interface Rel { abstract class Remap { public abstract List indices(); - public Type.Struct remap(Type.Struct initial) { - List types = initial.fields(); + public Type.Struct remap(final Type.Struct initial) { + final List types = initial.fields(); return TypeCreator.of(initial.nullable()).struct(indices().stream().map(i -> types.get(i))); } - public static Remap of(Iterable fields) { + public static Remap of(final Iterable fields) { return ImmutableRemap.builder().addAllIndices(fields).build(); } - public static Remap offset(int start, int length) { + public static Remap offset(final int start, final int length) { return of( IntStream.range(start, start + length) .mapToObj(i -> i) diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index b144ede1d..e5b0063db 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -36,12 +36,12 @@ public RelCopyOnWriteVisitor() { this.expressionCopyOnWriteVisitor = new ExpressionCopyOnWriteVisitor<>(this); } - public RelCopyOnWriteVisitor(ExpressionCopyOnWriteVisitor expressionCopyOnWriteVisitor) { + public RelCopyOnWriteVisitor(final ExpressionCopyOnWriteVisitor expressionCopyOnWriteVisitor) { this.expressionCopyOnWriteVisitor = expressionCopyOnWriteVisitor; } public RelCopyOnWriteVisitor( - Function, ExpressionCopyOnWriteVisitor> fn) { + final Function, ExpressionCopyOnWriteVisitor> fn) { this.expressionCopyOnWriteVisitor = fn.apply(this); } @@ -50,11 +50,12 @@ protected ExpressionCopyOnWriteVisitor getExpressionCopyOnWriteVisitor() { } @Override - public Optional visit(Aggregate aggregate, EmptyVisitationContext context) throws E { - Optional input = aggregate.getInput().accept(this, context); - Optional> groupings = + public Optional visit(final Aggregate aggregate, final EmptyVisitationContext context) + throws E { + final Optional input = aggregate.getInput().accept(this, context); + final Optional> groupings = transformList(aggregate.getGroupings(), context, this::visitGrouping); - Optional> measures = + final Optional> measures = transformList(aggregate.getMeasures(), context, this::visitMeasure); if (allEmpty(input, groupings, measures)) { @@ -70,16 +71,16 @@ public Optional visit(Aggregate aggregate, EmptyVisitationContext context) } protected Optional visitGrouping( - Aggregate.Grouping grouping, EmptyVisitationContext context) throws E { + final Aggregate.Grouping grouping, final EmptyVisitationContext context) throws E { return visitExprList(grouping.getExpressions(), context) .map(exprs -> Aggregate.Grouping.builder().from(grouping).expressions(exprs).build()); } protected Optional visitMeasure( - Aggregate.Measure measure, EmptyVisitationContext context) throws E { - Optional preMeasureFilter = + final Aggregate.Measure measure, final EmptyVisitationContext context) throws E { + final Optional preMeasureFilter = visitOptionalExpression(measure.getPreMeasureFilter(), context); - Optional afi = + final Optional afi = visitAggregateFunction(measure.getFunction(), context); if (allEmpty(preMeasureFilter, afi)) { @@ -94,9 +95,9 @@ protected Optional visitMeasure( } protected Optional visitAggregateFunction( - AggregateFunctionInvocation afi, EmptyVisitationContext context) throws E { - Optional> arguments = visitFunctionArguments(afi.arguments(), context); - Optional> sort = + final AggregateFunctionInvocation afi, final EmptyVisitationContext context) throws E { + final Optional> arguments = visitFunctionArguments(afi.arguments(), context); + final Optional> sort = transformList(afi.sort(), context, this::visitSortField); if (allEmpty(arguments, sort)) { @@ -111,8 +112,9 @@ protected Optional visitAggregateFunction( } @Override - public Optional visit(EmptyScan emptyScan, EmptyVisitationContext context) throws E { - Optional filter = visitOptionalExpression(emptyScan.getFilter(), context); + public Optional visit(final EmptyScan emptyScan, final EmptyVisitationContext context) + throws E { + final Optional filter = visitOptionalExpression(emptyScan.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -125,7 +127,7 @@ public Optional visit(EmptyScan emptyScan, EmptyVisitationContext context) } @Override - public Optional visit(Fetch fetch, EmptyVisitationContext context) throws E { + public Optional visit(final Fetch fetch, final EmptyVisitationContext context) throws E { return fetch .getInput() .accept(this, context) @@ -133,9 +135,9 @@ public Optional visit(Fetch fetch, EmptyVisitationContext context) throws E } @Override - public Optional visit(Filter filter, EmptyVisitationContext context) throws E { - Optional input = filter.getInput().accept(this, context); - Optional condition = + public Optional visit(final Filter filter, final EmptyVisitationContext context) throws E { + final Optional input = filter.getInput().accept(this, context); + final Optional condition = filter.getCondition().accept(getExpressionCopyOnWriteVisitor(), context); if (allEmpty(input, condition)) { @@ -150,11 +152,12 @@ public Optional visit(Filter filter, EmptyVisitationContext context) throws } @Override - public Optional visit(Join join, EmptyVisitationContext context) throws E { - Optional left = join.getLeft().accept(this, context); - Optional right = join.getRight().accept(this, context); - Optional condition = visitOptionalExpression(join.getCondition(), context); - Optional postFilter = visitOptionalExpression(join.getPostJoinFilter(), context); + public Optional visit(final Join join, final EmptyVisitationContext context) throws E { + final Optional left = join.getLeft().accept(this, context); + final Optional right = join.getRight().accept(this, context); + final Optional condition = visitOptionalExpression(join.getCondition(), context); + final Optional postFilter = + visitOptionalExpression(join.getPostJoinFilter(), context); if (allEmpty(left, right, condition, postFilter)) { return Optional.empty(); @@ -170,14 +173,15 @@ public Optional visit(Join join, EmptyVisitationContext context) throws E { } @Override - public Optional visit(Set set, EmptyVisitationContext context) throws E { + public Optional visit(final Set set, final EmptyVisitationContext context) throws E { return transformList(set.getInputs(), context, (t, c) -> t.accept(this, c)) .map(s -> Set.builder().from(set).inputs(s).build()); } @Override - public Optional visit(NamedScan namedScan, EmptyVisitationContext context) throws E { - Optional filter = visitOptionalExpression(namedScan.getFilter(), context); + public Optional visit(final NamedScan namedScan, final EmptyVisitationContext context) + throws E { + final Optional filter = visitOptionalExpression(namedScan.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -187,8 +191,9 @@ public Optional visit(NamedScan namedScan, EmptyVisitationContext context) } @Override - public Optional visit(LocalFiles localFiles, EmptyVisitationContext context) throws E { - Optional filter = visitOptionalExpression(localFiles.getFilter(), context); + public Optional visit(final LocalFiles localFiles, final EmptyVisitationContext context) + throws E { + final Optional filter = visitOptionalExpression(localFiles.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -198,9 +203,9 @@ public Optional visit(LocalFiles localFiles, EmptyVisitationContext context } @Override - public Optional visit(Project project, EmptyVisitationContext context) throws E { - Optional input = project.getInput().accept(this, context); - Optional> expressions = visitExprList(project.getExpressions(), context); + public Optional visit(final Project project, final EmptyVisitationContext context) throws E { + final Optional input = project.getInput().accept(this, context); + final Optional> expressions = visitExprList(project.getExpressions(), context); if (allEmpty(input, expressions)) { return Optional.empty(); @@ -214,14 +219,15 @@ public Optional visit(Project project, EmptyVisitationContext context) thro } @Override - public Optional visit(Expand expand, EmptyVisitationContext context) throws E { + public Optional visit(final Expand expand, final EmptyVisitationContext context) throws E { throw new UnsupportedOperationException(); } @Override - public Optional visit(NamedWrite write, EmptyVisitationContext context) throws E { + public Optional visit(final NamedWrite write, final EmptyVisitationContext context) + throws E { - Optional input = write.getInput().accept(this, context); + final Optional input = write.getInput().accept(this, context); if (allEmpty(input)) { return Optional.empty(); @@ -232,22 +238,25 @@ public Optional visit(NamedWrite write, EmptyVisitationContext context) thr } @Override - public Optional visit(ExtensionWrite write, EmptyVisitationContext context) throws E { + public Optional visit(final ExtensionWrite write, final EmptyVisitationContext context) + throws E { throw new UnsupportedOperationException(); } @Override - public Optional visit(NamedDdl ddl, EmptyVisitationContext context) throws E { + public Optional visit(final NamedDdl ddl, final EmptyVisitationContext context) throws E { throw new UnsupportedOperationException(); } @Override - public Optional visit(ExtensionDdl ddl, EmptyVisitationContext context) throws E { + public Optional visit(final ExtensionDdl ddl, final EmptyVisitationContext context) + throws E { throw new UnsupportedOperationException(); } protected Optional visitTransformExpression( - NamedUpdate.TransformExpression transform, EmptyVisitationContext context) throws E { + final NamedUpdate.TransformExpression transform, final EmptyVisitationContext context) + throws E { return transform .getTransformation() .accept(getExpressionCopyOnWriteVisitor(), context) @@ -260,11 +269,12 @@ protected Optional visitTransformExpression( } @Override - public Optional visit(NamedUpdate update, EmptyVisitationContext context) throws E { - Optional condition = + public Optional visit(final NamedUpdate update, final EmptyVisitationContext context) + throws E { + final Optional condition = update.getCondition().accept(getExpressionCopyOnWriteVisitor(), context); - Optional> transformations = + final Optional> transformations = transformList(update.getTransformations(), context, this::visitTransformExpression); if (allEmpty(condition, transformations)) { @@ -280,9 +290,10 @@ public Optional visit(NamedUpdate update, EmptyVisitationContext context) t } @Override - public Optional visit(ScatterExchange exchange, EmptyVisitationContext context) throws E { - Optional input = exchange.getInput().accept(this, context); - Optional> fields = + public Optional visit(final ScatterExchange exchange, final EmptyVisitationContext context) + throws E { + final Optional input = exchange.getInput().accept(this, context); + final Optional> fields = transformList(exchange.getFields(), context, this::visitFieldReference); if (allEmpty(input, fields)) { @@ -298,11 +309,11 @@ public Optional visit(ScatterExchange exchange, EmptyVisitationContext cont } @Override - public Optional visit(SingleBucketExchange exchange, EmptyVisitationContext context) - throws E { - Optional input = exchange.getInput().accept(this, context); + public Optional visit( + final SingleBucketExchange exchange, final EmptyVisitationContext context) throws E { + final Optional input = exchange.getInput().accept(this, context); - Optional expression = + final Optional expression = exchange.getExpression().accept(getExpressionCopyOnWriteVisitor(), context); if (allEmpty(input, expression)) { @@ -318,10 +329,10 @@ public Optional visit(SingleBucketExchange exchange, EmptyVisitationContext } @Override - public Optional visit(MultiBucketExchange exchange, EmptyVisitationContext context) - throws E { - Optional input = exchange.getInput().accept(this, context); - Optional expression = + public Optional visit( + final MultiBucketExchange exchange, final EmptyVisitationContext context) throws E { + final Optional input = exchange.getInput().accept(this, context); + final Optional expression = exchange.getExpression().accept(getExpressionCopyOnWriteVisitor(), context); if (allEmpty(input)) { @@ -337,8 +348,9 @@ public Optional visit(MultiBucketExchange exchange, EmptyVisitationContext } @Override - public Optional visit(RoundRobinExchange exchange, EmptyVisitationContext context) throws E { - Optional input = exchange.getInput().accept(this, context); + public Optional visit( + final RoundRobinExchange exchange, final EmptyVisitationContext context) throws E { + final Optional input = exchange.getInput().accept(this, context); if (allEmpty(input)) { return Optional.empty(); } @@ -351,8 +363,9 @@ public Optional visit(RoundRobinExchange exchange, EmptyVisitationContext c } @Override - public Optional visit(BroadcastExchange exchange, EmptyVisitationContext context) throws E { - Optional input = exchange.getInput().accept(this, context); + public Optional visit(final BroadcastExchange exchange, final EmptyVisitationContext context) + throws E { + final Optional input = exchange.getInput().accept(this, context); if (allEmpty(input)) { return Optional.empty(); } @@ -365,9 +378,9 @@ public Optional visit(BroadcastExchange exchange, EmptyVisitationContext co } @Override - public Optional visit(Sort sort, EmptyVisitationContext context) throws E { - Optional input = sort.getInput().accept(this, context); - Optional> sortFields = + public Optional visit(final Sort sort, final EmptyVisitationContext context) throws E { + final Optional input = sort.getInput().accept(this, context); + final Optional> sortFields = transformList(sort.getSortFields(), context, this::visitSortField); if (allEmpty(input, sortFields)) { @@ -382,9 +395,9 @@ public Optional visit(Sort sort, EmptyVisitationContext context) throws E { } @Override - public Optional visit(Cross cross, EmptyVisitationContext context) throws E { - Optional left = cross.getLeft().accept(this, context); - Optional right = cross.getRight().accept(this, context); + public Optional visit(final Cross cross, final EmptyVisitationContext context) throws E { + final Optional left = cross.getLeft().accept(this, context); + final Optional right = cross.getRight().accept(this, context); if (allEmpty(left, right)) { return Optional.empty(); @@ -398,9 +411,10 @@ public Optional visit(Cross cross, EmptyVisitationContext context) throws E } @Override - public Optional visit(VirtualTableScan virtualTableScan, EmptyVisitationContext context) - throws E { - Optional filter = visitOptionalExpression(virtualTableScan.getFilter(), context); + public Optional visit( + final VirtualTableScan virtualTableScan, final EmptyVisitationContext context) throws E { + final Optional filter = + visitOptionalExpression(virtualTableScan.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -413,13 +427,14 @@ public Optional visit(VirtualTableScan virtualTableScan, EmptyVisitationCon } @Override - public Optional visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) throws E { + public Optional visit( + final ExtensionLeaf extensionLeaf, final EmptyVisitationContext context) throws E { return Optional.empty(); } @Override - public Optional visit(ExtensionSingle extensionSingle, EmptyVisitationContext context) - throws E { + public Optional visit( + final ExtensionSingle extensionSingle, final EmptyVisitationContext context) throws E { return extensionSingle .getInput() .accept(this, context) @@ -427,16 +442,17 @@ public Optional visit(ExtensionSingle extensionSingle, EmptyVisitationConte } @Override - public Optional visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) - throws E { + public Optional visit( + final ExtensionMulti extensionMulti, final EmptyVisitationContext context) throws E { return transformList(extensionMulti.getInputs(), context, (rel, c) -> rel.accept(this, c)) .map(inputs -> ExtensionMulti.builder().from(extensionMulti).inputs(inputs).build()); } @Override - public Optional visit(ExtensionTable extensionTable, EmptyVisitationContext context) - throws E { - Optional filter = visitOptionalExpression(extensionTable.getFilter(), context); + public Optional visit( + final ExtensionTable extensionTable, final EmptyVisitationContext context) throws E { + final Optional filter = + visitOptionalExpression(extensionTable.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -449,14 +465,15 @@ public Optional visit(ExtensionTable extensionTable, EmptyVisitationContext } @Override - public Optional visit(HashJoin hashJoin, EmptyVisitationContext context) throws E { - Optional left = hashJoin.getLeft().accept(this, context); - Optional right = hashJoin.getRight().accept(this, context); - Optional> leftKeys = + public Optional visit(final HashJoin hashJoin, final EmptyVisitationContext context) + throws E { + final Optional left = hashJoin.getLeft().accept(this, context); + final Optional right = hashJoin.getRight().accept(this, context); + final Optional> leftKeys = transformList(hashJoin.getLeftKeys(), context, this::visitFieldReference); - Optional> rightKeys = + final Optional> rightKeys = transformList(hashJoin.getRightKeys(), context, this::visitFieldReference); - Optional postFilter = + final Optional postFilter = visitOptionalExpression(hashJoin.getPostJoinFilter(), context); if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) { @@ -474,14 +491,15 @@ public Optional visit(HashJoin hashJoin, EmptyVisitationContext context) th } @Override - public Optional visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws E { - Optional left = mergeJoin.getLeft().accept(this, context); - Optional right = mergeJoin.getRight().accept(this, context); - Optional> leftKeys = + public Optional visit(final MergeJoin mergeJoin, final EmptyVisitationContext context) + throws E { + final Optional left = mergeJoin.getLeft().accept(this, context); + final Optional right = mergeJoin.getRight().accept(this, context); + final Optional> leftKeys = transformList(mergeJoin.getLeftKeys(), context, this::visitFieldReference); - Optional> rightKeys = + final Optional> rightKeys = transformList(mergeJoin.getRightKeys(), context, this::visitFieldReference); - Optional postFilter = + final Optional postFilter = visitOptionalExpression(mergeJoin.getPostJoinFilter(), context); if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) { @@ -499,11 +517,11 @@ public Optional visit(MergeJoin mergeJoin, EmptyVisitationContext context) } @Override - public Optional visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) - throws E { - Optional left = nestedLoopJoin.getLeft().accept(this, context); - Optional right = nestedLoopJoin.getRight().accept(this, context); - Optional condition = + public Optional visit( + final NestedLoopJoin nestedLoopJoin, final EmptyVisitationContext context) throws E { + final Optional left = nestedLoopJoin.getLeft().accept(this, context); + final Optional right = nestedLoopJoin.getRight().accept(this, context); + final Optional condition = nestedLoopJoin.getCondition().accept(getExpressionCopyOnWriteVisitor(), context); if (allEmpty(left, right, condition)) { @@ -520,17 +538,18 @@ public Optional visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext @Override public Optional visit( - ConsistentPartitionWindow consistentPartitionWindow, EmptyVisitationContext context) + final ConsistentPartitionWindow consistentPartitionWindow, + final EmptyVisitationContext context) throws E { - Optional> windowFunctions = + final Optional> windowFunctions = transformList( consistentPartitionWindow.getWindowFunctions(), context, this::visitWindowRelFunction); - Optional> partitionExpressions = + final Optional> partitionExpressions = transformList( consistentPartitionWindow.getPartitionExpressions(), context, (t, c) -> t.accept(getExpressionCopyOnWriteVisitor(), c)); - Optional> sorts = + final Optional> sorts = transformList(consistentPartitionWindow.getSorts(), context, this::visitSortField); if (allEmpty(windowFunctions, partitionExpressions, sorts)) { @@ -548,10 +567,10 @@ public Optional visit( } protected Optional visitWindowRelFunction( - ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunctionInvocation, - EmptyVisitationContext context) + final ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunctionInvocation, + final EmptyVisitationContext context) throws E { - Optional> functionArgs = + final Optional> functionArgs = visitFunctionArguments(windowRelFunctionInvocation.arguments(), context); if (allEmpty(functionArgs)) { @@ -568,13 +587,13 @@ protected Optional visitW // utilities protected Optional> visitExprList( - List exprs, EmptyVisitationContext context) throws E { + final List exprs, final EmptyVisitationContext context) throws E { return transformList(exprs, context, (t, c) -> t.accept(getExpressionCopyOnWriteVisitor(), c)); } public Optional visitFieldReference( - FieldReference fieldReference, EmptyVisitationContext context) throws E { - Optional inputExpression = + final FieldReference fieldReference, final EmptyVisitationContext context) throws E { + final Optional inputExpression = visitOptionalExpression(fieldReference.inputExpression(), context); if (allEmpty(inputExpression)) { return Optional.empty(); @@ -584,7 +603,7 @@ public Optional visitFieldReference( } protected Optional> visitFunctionArguments( - List funcArgs, EmptyVisitationContext context) throws E { + final List funcArgs, final EmptyVisitationContext context) throws E { return CopyOnWriteUtils.transformList( funcArgs, context, @@ -600,7 +619,7 @@ protected Optional> visitFunctionArguments( } protected Optional visitSortField( - Expression.SortField sortField, EmptyVisitationContext context) throws E { + final Expression.SortField sortField, final EmptyVisitationContext context) throws E { return sortField .expr() .accept(getExpressionCopyOnWriteVisitor(), context) @@ -608,7 +627,7 @@ protected Optional visitSortField( } private Optional visitOptionalExpression( - Optional optExpr, EmptyVisitationContext context) throws E { + final Optional optExpr, final EmptyVisitationContext context) throws E { // not using optExpr.map to allow us to propagate the THROWABLE nicely if (optExpr.isPresent()) { return optExpr.get().accept(getExpressionCopyOnWriteVisitor(), context); diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 20ef1cce7..a59862c09 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -112,31 +112,32 @@ public TypeProtoConverter getTypeProtoConverter() { return this.typeProtoConverter; } - public io.substrait.proto.RelRoot toProto(Plan.Root relRoot) { + public io.substrait.proto.RelRoot toProto(final Plan.Root relRoot) { return RelRoot.newBuilder() .setInput(toProto(relRoot.getInput())) .addAllNames(relRoot.getNames()) .build(); } - public io.substrait.proto.Rel toProto(io.substrait.relation.Rel rel) { + public io.substrait.proto.Rel toProto(final io.substrait.relation.Rel rel) { return rel.accept(this, EmptyVisitationContext.INSTANCE); } - protected io.substrait.proto.Expression toProto(io.substrait.expression.Expression expression) { + protected io.substrait.proto.Expression toProto( + final io.substrait.expression.Expression expression) { return exprProtoConverter.toProto(expression); } protected List toProto( - List expression) { + final List expression) { return exprProtoConverter.toProto(expression); } - protected io.substrait.proto.Type toProto(io.substrait.type.Type type) { + protected io.substrait.proto.Type toProto(final io.substrait.type.Type type) { return typeProtoConverter.toProto(type); } - private List toProtoS(List sorts) { + private List toProtoS(final List sorts) { return sorts.stream() .map( s -> { @@ -148,13 +149,15 @@ private List toProtoS(List s .collect(Collectors.toList()); } - private io.substrait.proto.Expression.FieldReference toProto(FieldReference fieldReference) { + private io.substrait.proto.Expression.FieldReference toProto( + final FieldReference fieldReference) { return toProto((Expression) fieldReference).getSelection(); } @Override - public Rel visit(Aggregate aggregate, EmptyVisitationContext context) throws RuntimeException { - AggregateRel.Builder builder = + public Rel visit(final Aggregate aggregate, final EmptyVisitationContext context) + throws RuntimeException { + final AggregateRel.Builder builder = AggregateRel.newBuilder() .setInput(toProto(aggregate.getInput())) .setCommon(common(aggregate)) @@ -169,14 +172,14 @@ public Rel visit(Aggregate aggregate, EmptyVisitationContext context) throws Run return Rel.newBuilder().setAggregate(builder).build(); } - private AggregateRel.Measure toProto(Aggregate.Measure measure) { - FunctionArg.FuncArgVisitor< + private AggregateRel.Measure toProto(final Aggregate.Measure measure) { + final FunctionArg.FuncArgVisitor< io.substrait.proto.FunctionArgument, EmptyVisitationContext, RuntimeException> argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); - List args = measure.getFunction().arguments(); - SimpleExtension.AggregateFunctionVariant aggFuncDef = measure.getFunction().declaration(); + final List args = measure.getFunction().arguments(); + final SimpleExtension.AggregateFunctionVariant aggFuncDef = measure.getFunction().declaration(); - AggregateFunction.Builder func = + final AggregateFunction.Builder func = AggregateFunction.newBuilder() .setPhase(measure.getFunction().aggregationPhase().toProto()) .setInvocation(measure.getFunction().invocation().toProto()) @@ -196,20 +199,20 @@ private AggregateRel.Measure toProto(Aggregate.Measure measure) { .map(ExpressionProtoConverter::from) .collect(Collectors.toList())); - AggregateRel.Measure.Builder builder = AggregateRel.Measure.newBuilder().setMeasure(func); + final AggregateRel.Measure.Builder builder = AggregateRel.Measure.newBuilder().setMeasure(func); measure.getPreMeasureFilter().ifPresent(f -> builder.setFilter(toProto(f))); return builder.build(); } - private AggregateRel.Grouping toProto(Aggregate.Grouping grouping) { + private AggregateRel.Grouping toProto(final Aggregate.Grouping grouping) { return AggregateRel.Grouping.newBuilder() .addAllGroupingExpressions(toProto(grouping.getExpressions())) .build(); } @Override - public Rel visit(final EmptyScan emptyScan, EmptyVisitationContext context) + public Rel visit(final EmptyScan emptyScan, final EmptyVisitationContext context) throws RuntimeException { final ReadRel.Builder builder = ReadRel.newBuilder() @@ -223,8 +226,9 @@ public Rel visit(final EmptyScan emptyScan, EmptyVisitationContext context) } @Override - public Rel visit(Fetch fetch, EmptyVisitationContext context) throws RuntimeException { - FetchRel.Builder builder = + public Rel visit(final Fetch fetch, final EmptyVisitationContext context) + throws RuntimeException { + final FetchRel.Builder builder = FetchRel.newBuilder() .setCommon(common(fetch)) .setInput(toProto(fetch.getInput())) @@ -239,8 +243,9 @@ public Rel visit(Fetch fetch, EmptyVisitationContext context) throws RuntimeExce } @Override - public Rel visit(Filter filter, EmptyVisitationContext context) throws RuntimeException { - FilterRel.Builder builder = + public Rel visit(final Filter filter, final EmptyVisitationContext context) + throws RuntimeException { + final FilterRel.Builder builder = FilterRel.newBuilder() .setCommon(common(filter)) .setInput(toProto(filter.getInput())) @@ -253,8 +258,8 @@ public Rel visit(Filter filter, EmptyVisitationContext context) throws RuntimeEx } @Override - public Rel visit(Join join, EmptyVisitationContext context) throws RuntimeException { - JoinRel.Builder builder = + public Rel visit(final Join join, final EmptyVisitationContext context) throws RuntimeException { + final JoinRel.Builder builder = JoinRel.newBuilder() .setCommon(common(join)) .setLeft(toProto(join.getLeft())) @@ -271,8 +276,8 @@ public Rel visit(Join join, EmptyVisitationContext context) throws RuntimeExcept } @Override - public Rel visit(Set set, EmptyVisitationContext context) throws RuntimeException { - SetRel.Builder builder = + public Rel visit(final Set set, final EmptyVisitationContext context) throws RuntimeException { + final SetRel.Builder builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto()); set.getInputs() .forEach( @@ -286,8 +291,9 @@ public Rel visit(Set set, EmptyVisitationContext context) throws RuntimeExceptio } @Override - public Rel visit(NamedScan namedScan, EmptyVisitationContext context) throws RuntimeException { - ReadRel.Builder builder = + public Rel visit(final NamedScan namedScan, final EmptyVisitationContext context) + throws RuntimeException { + final ReadRel.Builder builder = ReadRel.newBuilder() .setCommon(common(namedScan)) .setNamedTable(ReadRel.NamedTable.newBuilder().addAllNames(namedScan.getNames())) @@ -303,8 +309,9 @@ public Rel visit(NamedScan namedScan, EmptyVisitationContext context) throws Run } @Override - public Rel visit(LocalFiles localFiles, EmptyVisitationContext context) throws RuntimeException { - ReadRel.Builder builder = + public Rel visit(final LocalFiles localFiles, final EmptyVisitationContext context) + throws RuntimeException { + final ReadRel.Builder builder = ReadRel.newBuilder() .setCommon(common(localFiles)) .setLocalFiles( @@ -325,11 +332,11 @@ public Rel visit(LocalFiles localFiles, EmptyVisitationContext context) throws R } @Override - public Rel visit(ExtensionTable extensionTable, EmptyVisitationContext context) + public Rel visit(final ExtensionTable extensionTable, final EmptyVisitationContext context) throws RuntimeException { - ReadRel.ExtensionTable.Builder extensionTableBuilder = + final ReadRel.ExtensionTable.Builder extensionTableBuilder = ReadRel.ExtensionTable.newBuilder().setDetail(extensionTable.getDetail().toProto(this)); - ReadRel.Builder builder = + final ReadRel.Builder builder = ReadRel.newBuilder() .setCommon(common(extensionTable)) .setBaseSchema(extensionTable.getInitialSchema().toProto(typeProtoConverter)) @@ -342,16 +349,17 @@ public Rel visit(ExtensionTable extensionTable, EmptyVisitationContext context) } @Override - public Rel visit(HashJoin hashJoin, EmptyVisitationContext context) throws RuntimeException { - HashJoinRel.Builder builder = + public Rel visit(final HashJoin hashJoin, final EmptyVisitationContext context) + throws RuntimeException { + final HashJoinRel.Builder builder = HashJoinRel.newBuilder() .setCommon(common(hashJoin)) .setLeft(toProto(hashJoin.getLeft())) .setRight(toProto(hashJoin.getRight())) .setType(hashJoin.getJoinType().toProto()); - List leftKeys = hashJoin.getLeftKeys(); - List rightKeys = hashJoin.getRightKeys(); + final List leftKeys = hashJoin.getLeftKeys(); + final List rightKeys = hashJoin.getRightKeys(); if (leftKeys.size() != rightKeys.size()) { throw new IllegalArgumentException("Number of left and right keys must be equal."); @@ -369,16 +377,17 @@ public Rel visit(HashJoin hashJoin, EmptyVisitationContext context) throws Runti } @Override - public Rel visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws RuntimeException { - MergeJoinRel.Builder builder = + public Rel visit(final MergeJoin mergeJoin, final EmptyVisitationContext context) + throws RuntimeException { + final MergeJoinRel.Builder builder = MergeJoinRel.newBuilder() .setCommon(common(mergeJoin)) .setLeft(toProto(mergeJoin.getLeft())) .setRight(toProto(mergeJoin.getRight())) .setType(mergeJoin.getJoinType().toProto()); - List leftKeys = mergeJoin.getLeftKeys(); - List rightKeys = mergeJoin.getRightKeys(); + final List leftKeys = mergeJoin.getLeftKeys(); + final List rightKeys = mergeJoin.getRightKeys(); if (leftKeys.size() != rightKeys.size()) { throw new IllegalArgumentException("Number of left and right keys must be equal."); @@ -396,9 +405,9 @@ public Rel visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws Run } @Override - public Rel visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) + public Rel visit(final NestedLoopJoin nestedLoopJoin, final EmptyVisitationContext context) throws RuntimeException { - NestedLoopJoinRel.Builder builder = + final NestedLoopJoinRel.Builder builder = NestedLoopJoinRel.newBuilder() .setCommon(common(nestedLoopJoin)) .setLeft(toProto(nestedLoopJoin.getLeft())) @@ -414,9 +423,10 @@ public Rel visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) @Override public Rel visit( - ConsistentPartitionWindow consistentPartitionWindow, EmptyVisitationContext context) + final ConsistentPartitionWindow consistentPartitionWindow, + final EmptyVisitationContext context) throws RuntimeException { - ConsistentPartitionWindowRel.Builder builder = + final ConsistentPartitionWindowRel.Builder builder = ConsistentPartitionWindowRel.newBuilder() .setCommon(common(consistentPartitionWindow)) .setInput(toProto(consistentPartitionWindow.getInput())) @@ -434,8 +444,9 @@ public Rel visit( } @Override - public Rel visit(NamedWrite write, EmptyVisitationContext context) throws RuntimeException { - WriteRel.Builder builder = + public Rel visit(final NamedWrite write, final EmptyVisitationContext context) + throws RuntimeException { + final WriteRel.Builder builder = WriteRel.newBuilder() .setCommon(common(write)) .setInput(toProto(write.getInput())) @@ -453,8 +464,9 @@ public Rel visit(NamedWrite write, EmptyVisitationContext context) throws Runtim } @Override - public Rel visit(ExtensionWrite write, EmptyVisitationContext context) throws RuntimeException { - WriteRel.Builder builder = + public Rel visit(final ExtensionWrite write, final EmptyVisitationContext context) + throws RuntimeException { + final WriteRel.Builder builder = WriteRel.newBuilder() .setCommon(common(write)) .setInput(toProto(write.getInput())) @@ -473,8 +485,9 @@ public Rel visit(ExtensionWrite write, EmptyVisitationContext context) throws Ru } @Override - public Rel visit(NamedDdl ddl, EmptyVisitationContext context) throws RuntimeException { - DdlRel.Builder builder = + public Rel visit(final NamedDdl ddl, final EmptyVisitationContext context) + throws RuntimeException { + final DdlRel.Builder builder = DdlRel.newBuilder() .setCommon(common(ddl)) .setTableSchema(ddl.getTableSchema().toProto(typeProtoConverter)) @@ -493,8 +506,9 @@ public Rel visit(NamedDdl ddl, EmptyVisitationContext context) throws RuntimeExc } @Override - public Rel visit(ExtensionDdl ddl, EmptyVisitationContext context) throws RuntimeException { - DdlRel.Builder builder = + public Rel visit(final ExtensionDdl ddl, final EmptyVisitationContext context) + throws RuntimeException { + final DdlRel.Builder builder = DdlRel.newBuilder() .setCommon(common(ddl)) .setTableSchema(ddl.getTableSchema().toProto(typeProtoConverter)) @@ -514,8 +528,9 @@ public Rel visit(ExtensionDdl ddl, EmptyVisitationContext context) throws Runtim } @Override - public Rel visit(NamedUpdate update, EmptyVisitationContext context) throws RuntimeException { - UpdateRel.Builder builder = + public Rel visit(final NamedUpdate update, final EmptyVisitationContext context) + throws RuntimeException { + final UpdateRel.Builder builder = UpdateRel.newBuilder() .setNamedTable(NamedTable.newBuilder().addAllNames(update.getNames())) .setTableSchema(update.getTableSchema().toProto(typeProtoConverter)) @@ -531,9 +546,9 @@ public Rel visit(NamedUpdate update, EmptyVisitationContext context) throws Runt } @Override - public Rel visit(ScatterExchange exchange, EmptyVisitationContext context) + public Rel visit(final ScatterExchange exchange, final EmptyVisitationContext context) throws RuntimeException { - ExchangeRel.Builder builder = + final ExchangeRel.Builder builder = ExchangeRel.newBuilder() .setScatterByFields( ExchangeRel.ScatterFields.newBuilder() @@ -551,9 +566,9 @@ public Rel visit(ScatterExchange exchange, EmptyVisitationContext context) } @Override - public Rel visit(SingleBucketExchange exchange, EmptyVisitationContext context) + public Rel visit(final SingleBucketExchange exchange, final EmptyVisitationContext context) throws RuntimeException { - ExchangeRel.Builder builder = + final ExchangeRel.Builder builder = ExchangeRel.newBuilder() .setSingleTarget( ExchangeRel.SingleBucketExpression.newBuilder() @@ -568,9 +583,9 @@ public Rel visit(SingleBucketExchange exchange, EmptyVisitationContext context) } @Override - public Rel visit(MultiBucketExchange exchange, EmptyVisitationContext context) + public Rel visit(final MultiBucketExchange exchange, final EmptyVisitationContext context) throws RuntimeException { - ExchangeRel.Builder builder = + final ExchangeRel.Builder builder = ExchangeRel.newBuilder() .setMultiTarget( ExchangeRel.MultiBucketExpression.newBuilder() @@ -586,9 +601,9 @@ public Rel visit(MultiBucketExchange exchange, EmptyVisitationContext context) } @Override - public Rel visit(RoundRobinExchange exchange, EmptyVisitationContext context) + public Rel visit(final RoundRobinExchange exchange, final EmptyVisitationContext context) throws RuntimeException { - ExchangeRel.Builder builder = + final ExchangeRel.Builder builder = ExchangeRel.newBuilder() .setRoundRobin( ExchangeRel.RoundRobin.newBuilder().setExact(exchange.getExact()).build()) @@ -601,9 +616,9 @@ public Rel visit(RoundRobinExchange exchange, EmptyVisitationContext context) } @Override - public Rel visit(BroadcastExchange exchange, EmptyVisitationContext context) + public Rel visit(final BroadcastExchange exchange, final EmptyVisitationContext context) throws RuntimeException { - ExchangeRel.Builder builder = + final ExchangeRel.Builder builder = ExchangeRel.newBuilder() .setBroadcast(ExchangeRel.Broadcast.newBuilder().build()) .setPartitionCount(exchange.getPartitionCount()) @@ -614,8 +629,8 @@ public Rel visit(BroadcastExchange exchange, EmptyVisitationContext context) return Rel.newBuilder().setExchange(builder).build(); } - private ExchangeRel.ExchangeTarget toProto(AbstractExchangeRel.ExchangeTarget target) { - ExchangeRel.ExchangeTarget.Builder builder = + private ExchangeRel.ExchangeTarget toProto(final AbstractExchangeRel.ExchangeTarget target) { + final ExchangeRel.ExchangeTarget.Builder builder = ExchangeRel.ExchangeTarget.newBuilder().addAllPartitionId(target.getPartitionIds()); if (target.getType() instanceof TargetType.Uri) { builder.setUri(((TargetType.Uri) target.getType()).getUri()); @@ -625,7 +640,7 @@ private ExchangeRel.ExchangeTarget toProto(AbstractExchangeRel.ExchangeTarget ta return builder.build(); } - UpdateRel.TransformExpression toProto(AbstractUpdate.TransformExpression transformation) { + UpdateRel.TransformExpression toProto(final AbstractUpdate.TransformExpression transformation) { return UpdateRel.TransformExpression.newBuilder() .setTransformation(toProto(transformation.getTransformation())) .setColumnTarget(transformation.getColumnTarget()) @@ -633,19 +648,19 @@ UpdateRel.TransformExpression toProto(AbstractUpdate.TransformExpression transfo } private List toProtoWindowRelFunctions( - Collection + final Collection windowRelFunctionInvocations) { return windowRelFunctionInvocations.stream() .map( f -> { - FunctionArg.FuncArgVisitor< + final FunctionArg.FuncArgVisitor< io.substrait.proto.FunctionArgument, EmptyVisitationContext, RuntimeException> argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter); - List args = f.arguments(); - SimpleExtension.WindowFunctionVariant aggFuncDef = f.declaration(); + final List args = f.arguments(); + final SimpleExtension.WindowFunctionVariant aggFuncDef = f.declaration(); - List arguments = + final List arguments = IntStream.range(0, args.size()) .mapToObj( i -> @@ -653,7 +668,7 @@ private List toProtoWindowRelFun .accept( aggFuncDef, i, argVisitor, EmptyVisitationContext.INSTANCE)) .collect(Collectors.toList()); - List options = + final List options = f.options().stream() .map(ExpressionProtoConverter::from) .collect(Collectors.toList()); @@ -674,8 +689,9 @@ private List toProtoWindowRelFun } @Override - public Rel visit(Project project, EmptyVisitationContext context) throws RuntimeException { - ProjectRel.Builder builder = + public Rel visit(final Project project, final EmptyVisitationContext context) + throws RuntimeException { + final ProjectRel.Builder builder = ProjectRel.newBuilder() .setCommon(common(project)) .setInput(toProto(project.getInput())) @@ -688,8 +704,9 @@ public Rel visit(Project project, EmptyVisitationContext context) throws Runtime } @Override - public Rel visit(Expand expand, EmptyVisitationContext context) throws RuntimeException { - ExpandRel.Builder builder = + public Rel visit(final Expand expand, final EmptyVisitationContext context) + throws RuntimeException { + final ExpandRel.Builder builder = ExpandRel.newBuilder().setCommon(common(expand)).setInput(toProto(expand.getInput())); expand @@ -697,14 +714,14 @@ public Rel visit(Expand expand, EmptyVisitationContext context) throws RuntimeEx .forEach( expandField -> { if (expandField instanceof Expand.ConsistentField) { - Expand.ConsistentField cf = (Expand.ConsistentField) expandField; + final Expand.ConsistentField cf = (Expand.ConsistentField) expandField; builder.addFields( ExpandRel.ExpandField.newBuilder() .setConsistentField(toProto(cf.getExpression())) .build()); } else if (expandField instanceof Expand.SwitchingField) { - Expand.SwitchingField sf = (Expand.SwitchingField) expandField; + final Expand.SwitchingField sf = (Expand.SwitchingField) expandField; builder.addFields( ExpandRel.ExpandField.newBuilder() .setSwitchingField( @@ -720,8 +737,8 @@ public Rel visit(Expand expand, EmptyVisitationContext context) throws RuntimeEx } @Override - public Rel visit(Sort sort, EmptyVisitationContext context) throws RuntimeException { - SortRel.Builder builder = + public Rel visit(final Sort sort, final EmptyVisitationContext context) throws RuntimeException { + final SortRel.Builder builder = SortRel.newBuilder() .setCommon(common(sort)) .setInput(toProto(sort.getInput())) @@ -733,8 +750,9 @@ public Rel visit(Sort sort, EmptyVisitationContext context) throws RuntimeExcept } @Override - public Rel visit(Cross cross, EmptyVisitationContext context) throws RuntimeException { - CrossRel.Builder builder = + public Rel visit(final Cross cross, final EmptyVisitationContext context) + throws RuntimeException { + final CrossRel.Builder builder = CrossRel.newBuilder() .setCommon(common(cross)) .setLeft(toProto(cross.getLeft())) @@ -747,9 +765,9 @@ public Rel visit(Cross cross, EmptyVisitationContext context) throws RuntimeExce } @Override - public Rel visit(VirtualTableScan virtualTableScan, EmptyVisitationContext context) + public Rel visit(final VirtualTableScan virtualTableScan, final EmptyVisitationContext context) throws RuntimeException { - ReadRel.Builder builder = + final ReadRel.Builder builder = ReadRel.newBuilder() .setCommon(common(virtualTableScan)) .setVirtualTable( @@ -772,9 +790,9 @@ public Rel visit(VirtualTableScan virtualTableScan, EmptyVisitationContext conte } @Override - public Rel visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) + public Rel visit(final ExtensionLeaf extensionLeaf, final EmptyVisitationContext context) throws RuntimeException { - ExtensionLeafRel.Builder builder = + final ExtensionLeafRel.Builder builder = ExtensionLeafRel.newBuilder() .setCommon(common(extensionLeaf)) .setDetail(extensionLeaf.getDetail().toProto(this)); @@ -782,9 +800,9 @@ public Rel visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) } @Override - public Rel visit(ExtensionSingle extensionSingle, EmptyVisitationContext context) + public Rel visit(final ExtensionSingle extensionSingle, final EmptyVisitationContext context) throws RuntimeException { - ExtensionSingleRel.Builder builder = + final ExtensionSingleRel.Builder builder = ExtensionSingleRel.newBuilder() .setCommon(common(extensionSingle)) .setInput(toProto(extensionSingle.getInput())) @@ -793,11 +811,11 @@ public Rel visit(ExtensionSingle extensionSingle, EmptyVisitationContext context } @Override - public Rel visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) + public Rel visit(final ExtensionMulti extensionMulti, final EmptyVisitationContext context) throws RuntimeException { - List inputs = + final List inputs = extensionMulti.getInputs().stream().map(this::toProto).collect(Collectors.toList()); - ExtensionMultiRel.Builder builder = + final ExtensionMultiRel.Builder builder = ExtensionMultiRel.newBuilder() .setCommon(common(extensionMulti)) .addAllInputs(inputs) @@ -805,13 +823,13 @@ public Rel visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) return Rel.newBuilder().setExtensionMulti(builder).build(); } - private RelCommon common(io.substrait.relation.Rel rel) { - RelCommon.Builder builder = RelCommon.newBuilder(); + private RelCommon common(final io.substrait.relation.Rel rel) { + final RelCommon.Builder builder = RelCommon.newBuilder(); rel.getCommonExtension() .ifPresent( extension -> builder.setAdvancedExtension(extensionProtoConverter.toProto(extension))); - io.substrait.relation.Rel.Remap remap = rel.getRemap().orElse(null); + final io.substrait.relation.Rel.Remap remap = rel.getRemap().orElse(null); if (remap != null) { builder.setEmit(RelCommon.Emit.newBuilder().addAllOutputMapping(remap.indices())); } else { @@ -819,15 +837,15 @@ private RelCommon common(io.substrait.relation.Rel rel) { } if (rel.getHint().isPresent()) { - io.substrait.hint.Hint hint = rel.getHint().get(); - Hint.Builder hintBuilder = Hint.newBuilder(); + final io.substrait.hint.Hint hint = rel.getHint().get(); + final Hint.Builder hintBuilder = Hint.newBuilder(); hint.getAlias().ifPresent(hintBuilder::setAlias); hintBuilder.addAllOutputNames(hint.getOutputNames()); if (hint.getStats().isPresent()) { - io.substrait.hint.Hint.Stats stats = hint.getStats().get(); - Stats.Builder statsBuilder = Stats.newBuilder(); + final io.substrait.hint.Hint.Stats stats = hint.getStats().get(); + final Stats.Builder statsBuilder = Stats.newBuilder(); stats .getExtension() @@ -838,8 +856,8 @@ private RelCommon common(io.substrait.relation.Rel rel) { } if (hint.getRuntimeConstraint().isPresent()) { - io.substrait.hint.Hint.RuntimeConstraint rc = hint.getRuntimeConstraint().get(); - RuntimeConstraint.Builder rcBuilder = RuntimeConstraint.newBuilder(); + final io.substrait.hint.Hint.RuntimeConstraint rc = hint.getRuntimeConstraint().get(); + final RuntimeConstraint.Builder rcBuilder = RuntimeConstraint.newBuilder(); rc.getExtension() .ifPresent(ae -> rcBuilder.setAdvancedExtension(extensionProtoConverter.toProto(ae))); diff --git a/core/src/main/java/io/substrait/relation/Set.java b/core/src/main/java/io/substrait/relation/Set.java index 0ebdcbddf..a110078e8 100644 --- a/core/src/main/java/io/substrait/relation/Set.java +++ b/core/src/main/java/io/substrait/relation/Set.java @@ -26,7 +26,7 @@ public enum SetOp { private SetRel.SetOp proto; - SetOp(SetRel.SetOp proto) { + SetOp(final SetRel.SetOp proto) { this.proto = proto; } @@ -34,8 +34,8 @@ public SetRel.SetOp toProto() { return proto; } - public static SetOp fromProto(SetRel.SetOp proto) { - for (SetOp v : values()) { + public static SetOp fromProto(final SetRel.SetOp proto) { + for (final SetOp v : values()) { if (v.proto == proto) { return v; } @@ -52,15 +52,15 @@ protected Type.Struct deriveRecordType() { // vs FIXEDCHAR (comes up in Isthmus tests). We also don't recurse into nullability // of the inner fields, in case the type itself is a struct or list or map. - List inputRecordTypes = + final List inputRecordTypes = getInputs().stream().map(Rel::getRecordType).collect(Collectors.toList()); if (inputRecordTypes.isEmpty()) { throw new IllegalArgumentException("Set operation must have at least one input"); } - Type.Struct first = inputRecordTypes.get(0); - List rest = inputRecordTypes.subList(1, inputRecordTypes.size()); + final Type.Struct first = inputRecordTypes.get(0); + final List rest = inputRecordTypes.subList(1, inputRecordTypes.size()); - int numFields = first.fields().size(); + final int numFields = first.fields().size(); if (rest.stream().anyMatch(t -> t.fields().size() != numFields)) { throw new IllegalArgumentException("Set's input records have different number of fields"); } @@ -87,12 +87,13 @@ protected Type.Struct deriveRecordType() { } /** If field is nullable in any of the inputs, it's nullable in the output */ - private Type.Struct coalesceNullabilityUnion(Type.Struct first, List rest) { + private Type.Struct coalesceNullabilityUnion( + final Type.Struct first, final List rest) { - List fields = new ArrayList<>(); + final List fields = new ArrayList<>(); for (int i = 0; i < first.fields().size(); i++) { - Type typeA = first.fields().get(i); - int finalI = i; + final Type typeA = first.fields().get(i); + final int finalI = i; fields.add( rest.stream() .map(struct -> struct.fields().get(finalI)) @@ -107,35 +108,38 @@ private Type.Struct coalesceNullabilityUnion(Type.Struct first, List rest) { + final Type.Struct first, final List rest) { - List fields = new ArrayList<>(); + final List fields = new ArrayList<>(); for (int i = 0; i < first.fields().size(); i++) { - Type typeA = first.fields().get(i); + final Type typeA = first.fields().get(i); if (!typeA.nullable()) { // Just to make this case explicit and to short-circuit, logic below would work without too fields.add(typeA); continue; } - int finalI = i; - boolean anyOtherIsNullable = rest.stream().anyMatch(t -> t.fields().get(finalI).nullable()); + final int finalI = i; + final boolean anyOtherIsNullable = + rest.stream().anyMatch(t -> t.fields().get(finalI).nullable()); fields.add(anyOtherIsNullable ? typeA : TypeCreator.asNotNullable(typeA)); } return Type.Struct.builder().fields(fields).nullable(first.nullable()).build(); } /** If field is required in any of the inputs, it's required in the output */ - private Type.Struct coalesceNullabilityIntersection(Type.Struct first, List rest) { - List fields = new ArrayList<>(); + private Type.Struct coalesceNullabilityIntersection( + final Type.Struct first, final List rest) { + final List fields = new ArrayList<>(); for (int i = 0; i < first.fields().size(); i++) { - Type typeA = first.fields().get(i); + final Type typeA = first.fields().get(i); if (!typeA.nullable()) { // Just to make this case explicit and to short-circuit, logic below would work without too fields.add(typeA); continue; } - int finalI = i; - boolean anyOtherIsRequired = rest.stream().anyMatch(t -> !t.fields().get(finalI).nullable()); + final int finalI = i; + final boolean anyOtherIsRequired = + rest.stream().anyMatch(t -> !t.fields().get(finalI).nullable()); fields.add(anyOtherIsRequired ? TypeCreator.asNotNullable(typeA) : typeA); } return Type.Struct.builder().fields(fields).nullable(first.nullable()).build(); @@ -143,7 +147,7 @@ private Type.Struct coalesceNullabilityIntersection(Type.Struct first, List O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/Sort.java b/core/src/main/java/io/substrait/relation/Sort.java index 39c7f5a1a..a5adb9d7a 100644 --- a/core/src/main/java/io/substrait/relation/Sort.java +++ b/core/src/main/java/io/substrait/relation/Sort.java @@ -18,7 +18,7 @@ protected Type.Struct deriveRecordType() { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index 9a54f6794..fccfac159 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -25,11 +25,11 @@ public abstract class VirtualTableScan extends AbstractReadRel { */ @Value.Check protected void check() { - List names = getInitialSchema().names(); + final List names = getInitialSchema().names(); assert names.size() == NamedFieldCountingTypeVisitor.countNames(this.getInitialSchema().struct()); - List rows = getRows(); + final List rows = getRows(); assert rows.size() > 0 && names.stream().noneMatch(s -> s == null) @@ -37,7 +37,7 @@ protected void check() { && rows.stream() .allMatch(r -> NamedFieldCountingTypeVisitor.countNames(r.getType()) == names.size()); - for (Expression.StructLiteral row : rows) { + for (final Expression.StructLiteral row : rows) { validateRowConformsToSchema(row); } } @@ -48,10 +48,10 @@ protected void check() { * @param row the row to validate * @throws AssertionError if the row does not conform to the schema */ - private void validateRowConformsToSchema(Expression.StructLiteral row) { - Type.Struct schemaStruct = getInitialSchema().struct(); - List schemaFieldTypes = schemaStruct.fields(); - List rowFields = row.fields(); + private void validateRowConformsToSchema(final Expression.StructLiteral row) { + final Type.Struct schemaStruct = getInitialSchema().struct(); + final List schemaFieldTypes = schemaStruct.fields(); + final List rowFields = row.fields(); assert rowFields.size() == schemaFieldTypes.size() : String.format( @@ -59,8 +59,8 @@ private void validateRowConformsToSchema(Expression.StructLiteral row) { rowFields.size(), schemaFieldTypes.size()); for (int i = 0; i < rowFields.size(); i++) { - Type rowFieldType = rowFields.get(i).getType(); - Type schemaFieldType = schemaFieldTypes.get(i); + final Type rowFieldType = rowFields.get(i).getType(); + final Type schemaFieldType = schemaFieldTypes.get(i); assert rowFieldType.equals(schemaFieldType) : String.format( @@ -71,7 +71,7 @@ private void validateRowConformsToSchema(Expression.StructLiteral row) { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } @@ -85,7 +85,7 @@ private static class NamedFieldCountingTypeVisitor private static final NamedFieldCountingTypeVisitor VISITOR = new NamedFieldCountingTypeVisitor(); - private static Integer countNames(Type type) { + private static Integer countNames(final Type type) { return type.accept(VISITOR); } @@ -210,24 +210,24 @@ public Integer visit(Type.Decimal type) throws RuntimeException { } @Override - public Integer visit(Type.Struct type) throws RuntimeException { + public Integer visit(final Type.Struct type) throws RuntimeException { // Only struct fields have names - the top level column names are also // captured by this since the whole schema is wrapped in a Struct type return type.fields().stream().mapToInt(field -> 1 + field.accept(this)).sum(); } @Override - public Integer visit(Type.ListType type) throws RuntimeException { + public Integer visit(final Type.ListType type) throws RuntimeException { return type.elementType().accept(this); } @Override - public Integer visit(Type.Map type) throws RuntimeException { + public Integer visit(final Type.Map type) throws RuntimeException { return type.key().accept(this) + type.value().accept(this); } @Override - public Integer visit(Type.UserDefined type) throws RuntimeException { + public Integer visit(final Type.UserDefined type) throws RuntimeException { return 0; } } diff --git a/core/src/main/java/io/substrait/relation/extensions/EmptyDetail.java b/core/src/main/java/io/substrait/relation/extensions/EmptyDetail.java index 552532d5c..982fea550 100644 --- a/core/src/main/java/io/substrait/relation/extensions/EmptyDetail.java +++ b/core/src/main/java/io/substrait/relation/extensions/EmptyDetail.java @@ -23,7 +23,7 @@ public class EmptyDetail Extension.DdlExtensionObject { @Override - public Any toProto(RelProtoConverter converter) { + public Any toProto(final RelProtoConverter converter) { return com.google.protobuf.Any.pack(com.google.protobuf.Empty.getDefaultInstance()); } @@ -33,12 +33,12 @@ public Type.Struct deriveRecordType() { } @Override - public Type.Struct deriveRecordType(Rel input) { + public Type.Struct deriveRecordType(final Rel input) { return TypeCreator.NULLABLE.struct(); } @Override - public Type.Struct deriveRecordType(List inputs) { + public Type.Struct deriveRecordType(final List inputs) { return TypeCreator.NULLABLE.struct(); } @@ -48,7 +48,7 @@ public NamedStruct deriveSchema() { } @Override - public boolean equals(Object o) { + public boolean equals(final Object o) { return o instanceof EmptyDetail; } diff --git a/core/src/main/java/io/substrait/relation/files/FileOrFiles.java b/core/src/main/java/io/substrait/relation/files/FileOrFiles.java index 4b227025f..ee05eea5e 100644 --- a/core/src/main/java/io/substrait/relation/files/FileOrFiles.java +++ b/core/src/main/java/io/substrait/relation/files/FileOrFiles.java @@ -31,7 +31,8 @@ static ImmutableFileOrFiles.Builder builder() { } default ReadRel.LocalFiles.FileOrFiles toProto() { - ReadRel.LocalFiles.FileOrFiles.Builder builder = ReadRel.LocalFiles.FileOrFiles.newBuilder(); + final ReadRel.LocalFiles.FileOrFiles.Builder builder = + ReadRel.LocalFiles.FileOrFiles.newBuilder(); getFileFormat() .ifPresent( @@ -48,9 +49,9 @@ default ReadRel.LocalFiles.FileOrFiles toProto() { builder.setDwrf( ReadRel.LocalFiles.FileOrFiles.DwrfReadOptions.newBuilder().build()); } else if (fileFormat instanceof FileFormat.DelimiterSeparatedTextReadOptions) { - FileFormat.DelimiterSeparatedTextReadOptions options = + final FileFormat.DelimiterSeparatedTextReadOptions options = (FileFormat.DelimiterSeparatedTextReadOptions) fileFormat; - ReadRel.LocalFiles.FileOrFiles.DelimiterSeparatedTextReadOptions.Builder + final ReadRel.LocalFiles.FileOrFiles.DelimiterSeparatedTextReadOptions.Builder optionsBuilder = ReadRel.LocalFiles.FileOrFiles.DelimiterSeparatedTextReadOptions .newBuilder() @@ -62,7 +63,7 @@ default ReadRel.LocalFiles.FileOrFiles toProto() { options.getValueTreatedAsNull().ifPresent(optionsBuilder::setValueTreatedAsNull); builder.setText(optionsBuilder.build()); } else if (fileFormat instanceof FileFormat.Extension) { - FileFormat.Extension options = (FileFormat.Extension) fileFormat; + final FileFormat.Extension options = (FileFormat.Extension) fileFormat; builder.setExtension(options.getExtension()); } else { throw new UnsupportedOperationException( diff --git a/core/src/main/java/io/substrait/relation/physical/BroadcastExchange.java b/core/src/main/java/io/substrait/relation/physical/BroadcastExchange.java index 10dc1e532..c0e19d695 100644 --- a/core/src/main/java/io/substrait/relation/physical/BroadcastExchange.java +++ b/core/src/main/java/io/substrait/relation/physical/BroadcastExchange.java @@ -8,7 +8,7 @@ public abstract class BroadcastExchange extends AbstractExchangeRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/physical/HashJoin.java b/core/src/main/java/io/substrait/relation/physical/HashJoin.java index ca49bcd46..031531814 100644 --- a/core/src/main/java/io/substrait/relation/physical/HashJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/HashJoin.java @@ -38,12 +38,12 @@ public enum JoinType { private HashJoinRel.JoinType proto; - JoinType(HashJoinRel.JoinType proto) { + JoinType(final HashJoinRel.JoinType proto) { this.proto = proto; } - public static JoinType fromProto(HashJoinRel.JoinType proto) { - for (JoinType v : values()) { + public static JoinType fromProto(final HashJoinRel.JoinType proto) { + for (final JoinType v : values()) { if (v.proto == proto) { return v; } @@ -58,8 +58,8 @@ public HashJoinRel.JoinType toProto() { @Override protected Type.Struct deriveRecordType() { - Stream leftTypes = getLeftTypes(); - Stream rightTypes = getRightTypes(); + final Stream leftTypes = getLeftTypes(); + final Stream rightTypes = getRightTypes(); return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } @@ -91,7 +91,7 @@ private Stream getRightTypes() { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/physical/MergeJoin.java b/core/src/main/java/io/substrait/relation/physical/MergeJoin.java index 9664868bb..08cff4706 100644 --- a/core/src/main/java/io/substrait/relation/physical/MergeJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/MergeJoin.java @@ -38,12 +38,12 @@ public enum JoinType { private MergeJoinRel.JoinType proto; - JoinType(MergeJoinRel.JoinType proto) { + JoinType(final MergeJoinRel.JoinType proto) { this.proto = proto; } - public static JoinType fromProto(MergeJoinRel.JoinType proto) { - for (JoinType v : values()) { + public static JoinType fromProto(final MergeJoinRel.JoinType proto) { + for (final JoinType v : values()) { if (v.proto == proto) { return v; } @@ -58,8 +58,8 @@ public MergeJoinRel.JoinType toProto() { @Override protected Type.Struct deriveRecordType() { - Stream leftTypes = getLeftTypes(); - Stream rightTypes = getRightTypes(); + final Stream leftTypes = getLeftTypes(); + final Stream rightTypes = getRightTypes(); return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } @@ -91,7 +91,7 @@ private Stream getRightTypes() { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/physical/MultiBucketExchange.java b/core/src/main/java/io/substrait/relation/physical/MultiBucketExchange.java index 9f1c2f31e..e4a3cd314 100644 --- a/core/src/main/java/io/substrait/relation/physical/MultiBucketExchange.java +++ b/core/src/main/java/io/substrait/relation/physical/MultiBucketExchange.java @@ -13,7 +13,7 @@ public abstract class MultiBucketExchange extends AbstractExchangeRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java index 25c0a509d..18f95ee29 100644 --- a/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java @@ -31,7 +31,7 @@ public enum JoinType { private NestedLoopJoinRel.JoinType proto; - JoinType(NestedLoopJoinRel.JoinType proto) { + JoinType(final NestedLoopJoinRel.JoinType proto) { this.proto = proto; } @@ -39,8 +39,8 @@ public NestedLoopJoinRel.JoinType toProto() { return proto; } - public static JoinType fromProto(NestedLoopJoinRel.JoinType proto) { - for (JoinType v : values()) { + public static JoinType fromProto(final NestedLoopJoinRel.JoinType proto) { + for (final JoinType v : values()) { if (v.proto == proto) { return v; } @@ -52,8 +52,8 @@ public static JoinType fromProto(NestedLoopJoinRel.JoinType proto) { @Override protected Type.Struct deriveRecordType() { - Stream leftTypes = getLeftTypes(); - Stream rightTypes = getRightTypes(); + final Stream leftTypes = getLeftTypes(); + final Stream rightTypes = getRightTypes(); return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } @@ -85,7 +85,7 @@ private Stream getRightTypes() { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/physical/RoundRobinExchange.java b/core/src/main/java/io/substrait/relation/physical/RoundRobinExchange.java index 3bbb3e370..d72d29c0e 100644 --- a/core/src/main/java/io/substrait/relation/physical/RoundRobinExchange.java +++ b/core/src/main/java/io/substrait/relation/physical/RoundRobinExchange.java @@ -10,7 +10,7 @@ public abstract class RoundRobinExchange extends AbstractExchangeRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/physical/ScatterExchange.java b/core/src/main/java/io/substrait/relation/physical/ScatterExchange.java index 6f8f99977..787661a1f 100644 --- a/core/src/main/java/io/substrait/relation/physical/ScatterExchange.java +++ b/core/src/main/java/io/substrait/relation/physical/ScatterExchange.java @@ -12,7 +12,7 @@ public abstract class ScatterExchange extends AbstractExchangeRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/relation/physical/SingleBucketExchange.java b/core/src/main/java/io/substrait/relation/physical/SingleBucketExchange.java index 6446d8389..b0e211a6f 100644 --- a/core/src/main/java/io/substrait/relation/physical/SingleBucketExchange.java +++ b/core/src/main/java/io/substrait/relation/physical/SingleBucketExchange.java @@ -11,7 +11,7 @@ public abstract class SingleBucketExchange extends AbstractExchangeRel { @Override public O accept( - RelVisitor visitor, C context) throws E { + final RelVisitor visitor, final C context) throws E { return visitor.visit(this, context); } diff --git a/core/src/main/java/io/substrait/type/Deserializers.java b/core/src/main/java/io/substrait/type/Deserializers.java index 13936dc07..ee087631c 100644 --- a/core/src/main/java/io/substrait/type/Deserializers.java +++ b/core/src/main/java/io/substrait/type/Deserializers.java @@ -35,7 +35,8 @@ public static class ParseDeserializer extends StdDeserializer { private final BiFunction converter; public ParseDeserializer( - Class clazz, BiFunction converter) { + final Class clazz, + final BiFunction converter) { super(clazz); this.converter = converter; } @@ -43,11 +44,12 @@ public ParseDeserializer( @Override public T deserialize(final JsonParser p, final DeserializationContext ctxt) throws IOException, JsonProcessingException { - String typeString = p.getValueAsString(); + final String typeString = p.getValueAsString(); try { - String urn = (String) ctxt.findInjectableValue(SimpleExtension.URN_LOCATOR_KEY, null, null); + final String urn = + (String) ctxt.findInjectableValue(SimpleExtension.URN_LOCATOR_KEY, null, null); return TypeStringParser.parse(typeString, urn, converter); - } catch (Exception ex) { + } catch (final Exception ex) { throw JsonMappingException.from( p, "Unable to parse string " + typeString.replace("\n", " \\n"), ex); } diff --git a/core/src/main/java/io/substrait/type/NamedStruct.java b/core/src/main/java/io/substrait/type/NamedStruct.java index 9c241542c..cbf56930f 100644 --- a/core/src/main/java/io/substrait/type/NamedStruct.java +++ b/core/src/main/java/io/substrait/type/NamedStruct.java @@ -15,12 +15,12 @@ static ImmutableNamedStruct.Builder builder() { return ImmutableNamedStruct.builder(); } - static NamedStruct of(Iterable names, Type.Struct type) { + static NamedStruct of(final Iterable names, final Type.Struct type) { return ImmutableNamedStruct.builder().addAllNames(names).struct(type).build(); } - default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConverter) { - io.substrait.proto.Type type = struct().accept(typeProtoConverter); + default io.substrait.proto.NamedStruct toProto(final TypeProtoConverter typeProtoConverter) { + final io.substrait.proto.Type type = struct().accept(typeProtoConverter); return io.substrait.proto.NamedStruct.newBuilder() .setStruct(type.getStruct()) .addAllNames(names()) @@ -28,8 +28,9 @@ default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConve } static io.substrait.type.NamedStruct fromProto( - io.substrait.proto.NamedStruct namedStruct, ProtoTypeConverter protoTypeConverter) { - io.substrait.proto.Type.Struct struct = namedStruct.getStruct(); + final io.substrait.proto.NamedStruct namedStruct, + final ProtoTypeConverter protoTypeConverter) { + final io.substrait.proto.Type.Struct struct = namedStruct.getStruct(); return ImmutableNamedStruct.builder() .names(namedStruct.getNamesList()) .struct( diff --git a/core/src/main/java/io/substrait/type/StringTypeVisitor.java b/core/src/main/java/io/substrait/type/StringTypeVisitor.java index d7c196148..619b723df 100644 --- a/core/src/main/java/io/substrait/type/StringTypeVisitor.java +++ b/core/src/main/java/io/substrait/type/StringTypeVisitor.java @@ -4,150 +4,150 @@ public class StringTypeVisitor implements TypeVisitor { - private String n(Type type) { + private String n(final Type type) { return type.nullable() ? "?" : ""; } @Override - public String visit(Type.Bool type) throws RuntimeException { + public String visit(final Type.Bool type) throws RuntimeException { return "boolean" + n(type); } @Override - public String visit(Type.I8 type) throws RuntimeException { + public String visit(final Type.I8 type) throws RuntimeException { return "i8" + n(type); } @Override - public String visit(Type.I16 type) throws RuntimeException { + public String visit(final Type.I16 type) throws RuntimeException { return "i16" + n(type); } @Override - public String visit(Type.I32 type) throws RuntimeException { + public String visit(final Type.I32 type) throws RuntimeException { return "i32" + n(type); } @Override - public String visit(Type.I64 type) throws RuntimeException { + public String visit(final Type.I64 type) throws RuntimeException { return "i64" + n(type); } @Override - public String visit(Type.FP32 type) throws RuntimeException { + public String visit(final Type.FP32 type) throws RuntimeException { return "fp32" + n(type); } @Override - public String visit(Type.FP64 type) throws RuntimeException { + public String visit(final Type.FP64 type) throws RuntimeException { return "fp64" + n(type); } @Override - public String visit(Type.Str type) throws RuntimeException { + public String visit(final Type.Str type) throws RuntimeException { return "string" + n(type); } @Override - public String visit(Type.Binary type) throws RuntimeException { + public String visit(final Type.Binary type) throws RuntimeException { return "binary" + n(type); } @Override - public String visit(Type.Date type) throws RuntimeException { + public String visit(final Type.Date type) throws RuntimeException { return "date" + n(type); } @Override - public String visit(Type.Time type) throws RuntimeException { + public String visit(final Type.Time type) throws RuntimeException { return "time" + n(type); } @Override - public String visit(Type.TimestampTZ type) throws RuntimeException { + public String visit(final Type.TimestampTZ type) throws RuntimeException { return "timestamp_tz" + n(type); } @Override - public String visit(Type.Timestamp type) throws RuntimeException { + public String visit(final Type.Timestamp type) throws RuntimeException { return "timestamp" + n(type); } @Override - public String visit(Type.IntervalYear type) throws RuntimeException { + public String visit(final Type.IntervalYear type) throws RuntimeException { return "interval_year" + n(type); } @Override - public String visit(Type.IntervalDay type) throws RuntimeException { + public String visit(final Type.IntervalDay type) throws RuntimeException { return "interval_day" + n(type); } @Override - public String visit(Type.IntervalCompound type) throws RuntimeException { + public String visit(final Type.IntervalCompound type) throws RuntimeException { return "interval_compound" + n(type); } @Override - public String visit(Type.UUID type) throws RuntimeException { + public String visit(final Type.UUID type) throws RuntimeException { return "uuid" + n(type); } @Override - public String visit(Type.FixedChar type) throws RuntimeException { + public String visit(final Type.FixedChar type) throws RuntimeException { return String.format("char<%d>%s", type.length(), n(type)); } @Override - public String visit(Type.VarChar type) throws RuntimeException { + public String visit(final Type.VarChar type) throws RuntimeException { return String.format("varchar<%d>%s", type.length(), n(type)); } @Override - public String visit(Type.FixedBinary type) throws RuntimeException { + public String visit(final Type.FixedBinary type) throws RuntimeException { return String.format("fixedbinary<%d>%s", type.length(), n(type)); } @Override - public String visit(Type.Decimal type) throws RuntimeException { + public String visit(final Type.Decimal type) throws RuntimeException { return String.format("decimal<%d,%d>%s", type.precision(), type.scale(), n(type)); } @Override - public String visit(Type.PrecisionTime type) throws RuntimeException { + public String visit(final Type.PrecisionTime type) throws RuntimeException { return String.format("precision_time<%d>%s", type.precision(), n(type)); } @Override - public String visit(Type.PrecisionTimestamp type) throws RuntimeException { + public String visit(final Type.PrecisionTimestamp type) throws RuntimeException { return String.format("precision_timestamp<%d>%s", type.precision(), n(type)); } @Override - public String visit(Type.PrecisionTimestampTZ type) throws RuntimeException { + public String visit(final Type.PrecisionTimestampTZ type) throws RuntimeException { return String.format("precision_timestamp_tz<%d>%s", type.precision(), n(type)); } @Override - public String visit(Type.Struct type) throws RuntimeException { + public String visit(final Type.Struct type) throws RuntimeException { return String.format( "struct<%s>%s", type.fields().stream().map(t -> t.accept(this)).collect(Collectors.joining(", ")), n(type)); } @Override - public String visit(Type.ListType type) throws RuntimeException { + public String visit(final Type.ListType type) throws RuntimeException { return String.format("list<%s>%s", type.elementType().accept(this), n(type)); } @Override - public String visit(Type.Map type) throws RuntimeException { + public String visit(final Type.Map type) throws RuntimeException { return String.format( "map<%s,%s>%s", type.key().accept(this), type.value().accept(this), n(type)); } @Override - public String visit(Type.UserDefined type) throws RuntimeException { + public String visit(final Type.UserDefined type) throws RuntimeException { return String.format("u!%s%s", type.name(), n(type)); } } diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index aaf97aa12..bd16d4d44 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -11,16 +11,19 @@ @Value.Enclosing public interface Type extends TypeExpression, ParameterizedType, NullableType, FunctionArg { - static TypeCreator withNullability(boolean nullable) { + static TypeCreator withNullability(final boolean nullable) { return nullable ? TypeCreator.NULLABLE : TypeCreator.REQUIRED; } @Override - R accept(final TypeVisitor typeVisitor) throws E; + R accept(TypeVisitor typeVisitor) throws E; @Override default R accept( - SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor, C context) + final SimpleExtension.Function fnDef, + final int argIdx, + final FuncArgVisitor fnArgVisitor, + final C context) throws E { return fnArgVisitor.visitType(fnDef, argIdx, this, context); } @@ -398,7 +401,7 @@ public static ImmutableType.UserDefined.Builder builder() { } @Override - public R accept(TypeVisitor typeVisitor) throws E { + public R accept(final TypeVisitor typeVisitor) throws E { return typeVisitor.visit(this); } } diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java b/core/src/main/java/io/substrait/type/TypeCreator.java index 43358e505..a63e7fa8b 100644 --- a/core/src/main/java/io/substrait/type/TypeCreator.java +++ b/core/src/main/java/io/substrait/type/TypeCreator.java @@ -30,7 +30,7 @@ public class TypeCreator { private static NullableSettingTypeVisitor NULLABLE_FALSE_VISITOR = new NullableSettingTypeVisitor(false); - protected TypeCreator(boolean nullable) { + protected TypeCreator(final boolean nullable) { this.nullable = nullable; BOOLEAN = Type.Bool.builder().nullable(nullable).build(); I8 = Type.I8.builder().nullable(nullable).build(); @@ -49,78 +49,78 @@ protected TypeCreator(boolean nullable) { UUID = Type.UUID.builder().nullable(nullable).build(); } - public Type fixedChar(int len) { + public Type fixedChar(final int len) { return Type.FixedChar.builder().nullable(nullable).length(len).build(); } - public final Type varChar(int len) { + public final Type varChar(final int len) { return Type.VarChar.builder().nullable(nullable).length(len).build(); } - public final Type fixedBinary(int len) { + public final Type fixedBinary(final int len) { return Type.FixedBinary.builder().nullable(nullable).length(len).build(); } - public final Type decimal(int precision, int scale) { + public final Type decimal(final int precision, final int scale) { return Type.Decimal.builder().nullable(nullable).precision(precision).scale(scale).build(); } - public final Type.Struct struct(Type... types) { + public final Type.Struct struct(final Type... types) { return Type.Struct.builder().nullable(nullable).addFields(types).build(); } - public final Type precisionTime(int precision) { + public final Type precisionTime(final int precision) { return Type.PrecisionTime.builder().nullable(nullable).precision(precision).build(); } - public final Type precisionTimestamp(int precision) { + public final Type precisionTimestamp(final int precision) { return Type.PrecisionTimestamp.builder().nullable(nullable).precision(precision).build(); } - public final Type precisionTimestampTZ(int precision) { + public final Type precisionTimestampTZ(final int precision) { return Type.PrecisionTimestampTZ.builder().nullable(nullable).precision(precision).build(); } - public final Type intervalDay(int precision) { + public final Type intervalDay(final int precision) { return Type.IntervalDay.builder().nullable(nullable).precision(precision).build(); } - public final Type intervalCompound(int precision) { + public final Type intervalCompound(final int precision) { return Type.IntervalCompound.builder().nullable(nullable).precision(precision).build(); } - public Type.Struct struct(Iterable types) { + public Type.Struct struct(final Iterable types) { return Type.Struct.builder().nullable(nullable).addAllFields(types).build(); } - public Type.Struct struct(Stream types) { + public Type.Struct struct(final Stream types) { return Type.Struct.builder() .nullable(nullable) .addAllFields(types.collect(Collectors.toList())) .build(); } - public Type.ListType list(Type type) { + public Type.ListType list(final Type type) { return Type.ListType.builder().nullable(nullable).elementType(type).build(); } - public Type.Map map(Type key, Type value) { + public Type.Map map(final Type key, final Type value) { return Type.Map.builder().nullable(nullable).key(key).value(value).build(); } - public Type userDefined(String urn, String name) { + public Type userDefined(final String urn, final String name) { return Type.UserDefined.builder().nullable(nullable).urn(urn).name(name).build(); } - public static TypeCreator of(boolean nullability) { + public static TypeCreator of(final boolean nullability) { return nullability ? NULLABLE : REQUIRED; } - public static Type asNullable(Type type) { + public static Type asNullable(final Type type) { return type.nullable() ? type : type.accept(NULLABLE_TRUE_VISITOR); } - public static Type asNotNullable(Type type) { + public static Type asNotNullable(final Type type) { return type.nullable() ? type.accept(NULLABLE_FALSE_VISITOR) : type; } @@ -129,147 +129,147 @@ private static final class NullableSettingTypeVisitor private final boolean nullability; - NullableSettingTypeVisitor(boolean nullability) { + NullableSettingTypeVisitor(final boolean nullability) { this.nullability = nullability; } @Override - public Type visit(Type.Bool type) throws RuntimeException { + public Type visit(final Type.Bool type) throws RuntimeException { return Type.Bool.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.I8 type) throws RuntimeException { + public Type visit(final Type.I8 type) throws RuntimeException { return Type.I8.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.I16 type) throws RuntimeException { + public Type visit(final Type.I16 type) throws RuntimeException { return Type.I16.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.I32 type) throws RuntimeException { + public Type visit(final Type.I32 type) throws RuntimeException { return Type.I32.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.I64 type) throws RuntimeException { + public Type visit(final Type.I64 type) throws RuntimeException { return Type.I64.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.FP32 type) throws RuntimeException { + public Type visit(final Type.FP32 type) throws RuntimeException { return Type.FP32.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.FP64 type) throws RuntimeException { + public Type visit(final Type.FP64 type) throws RuntimeException { return Type.FP64.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.Str type) throws RuntimeException { + public Type visit(final Type.Str type) throws RuntimeException { return Type.Str.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.Binary type) throws RuntimeException { + public Type visit(final Type.Binary type) throws RuntimeException { return Type.Binary.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.Date type) throws RuntimeException { + public Type visit(final Type.Date type) throws RuntimeException { return Type.Date.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.Time type) throws RuntimeException { + public Type visit(final Type.Time type) throws RuntimeException { return Type.Time.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.TimestampTZ type) throws RuntimeException { + public Type visit(final Type.TimestampTZ type) throws RuntimeException { return Type.TimestampTZ.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.Timestamp type) throws RuntimeException { + public Type visit(final Type.Timestamp type) throws RuntimeException { return Type.Timestamp.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.IntervalYear type) throws RuntimeException { + public Type visit(final Type.IntervalYear type) throws RuntimeException { return Type.IntervalYear.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.IntervalDay type) throws RuntimeException { + public Type visit(final Type.IntervalDay type) throws RuntimeException { return Type.IntervalDay.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.IntervalCompound type) throws RuntimeException { + public Type visit(final Type.IntervalCompound type) throws RuntimeException { return Type.IntervalCompound.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.UUID type) throws RuntimeException { + public Type visit(final Type.UUID type) throws RuntimeException { return Type.UUID.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.FixedChar type) throws RuntimeException { + public Type visit(final Type.FixedChar type) throws RuntimeException { return Type.FixedChar.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.VarChar type) throws RuntimeException { + public Type visit(final Type.VarChar type) throws RuntimeException { return Type.VarChar.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.FixedBinary type) throws RuntimeException { + public Type visit(final Type.FixedBinary type) throws RuntimeException { return Type.FixedBinary.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.Decimal type) throws RuntimeException { + public Type visit(final Type.Decimal type) throws RuntimeException { return Type.Decimal.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.PrecisionTime type) throws RuntimeException { + public Type visit(final Type.PrecisionTime type) throws RuntimeException { return Type.PrecisionTime.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.PrecisionTimestamp type) throws RuntimeException { + public Type visit(final Type.PrecisionTimestamp type) throws RuntimeException { return Type.PrecisionTimestamp.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.PrecisionTimestampTZ type) throws RuntimeException { + public Type visit(final Type.PrecisionTimestampTZ type) throws RuntimeException { return Type.PrecisionTimestampTZ.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.Struct type) throws RuntimeException { + public Type visit(final Type.Struct type) throws RuntimeException { return Type.Struct.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.ListType type) throws RuntimeException { + public Type visit(final Type.ListType type) throws RuntimeException { return Type.ListType.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.Map type) throws RuntimeException { + public Type visit(final Type.Map type) throws RuntimeException { return Type.Map.builder().from(type).nullable(nullability).build(); } @Override - public Type visit(Type.UserDefined type) throws RuntimeException { + public Type visit(final Type.UserDefined type) throws RuntimeException { return Type.UserDefined.builder().from(type).nullable(nullability).build(); } } diff --git a/core/src/main/java/io/substrait/type/TypeExpressionEvaluator.java b/core/src/main/java/io/substrait/type/TypeExpressionEvaluator.java index 5be19898b..58299388a 100644 --- a/core/src/main/java/io/substrait/type/TypeExpressionEvaluator.java +++ b/core/src/main/java/io/substrait/type/TypeExpressionEvaluator.java @@ -7,9 +7,9 @@ public class TypeExpressionEvaluator { public static Type evaluateExpression( - TypeExpression returnExpression, - List parameterizedTypeList, - List actualTypes) { + final TypeExpression returnExpression, + final List parameterizedTypeList, + final List actualTypes) { if (returnExpression instanceof Type) { return (Type) returnExpression; diff --git a/core/src/main/java/io/substrait/type/TypeVisitor.java b/core/src/main/java/io/substrait/type/TypeVisitor.java index ce6a08910..b65ef0cda 100644 --- a/core/src/main/java/io/substrait/type/TypeVisitor.java +++ b/core/src/main/java/io/substrait/type/TypeVisitor.java @@ -63,7 +63,7 @@ abstract class TypeThrowsVisitor implements TypeVisitor< private final String unsupportedMessage; - protected TypeThrowsVisitor(String unsupportedMessage) { + protected TypeThrowsVisitor(final String unsupportedMessage) { this.unsupportedMessage = unsupportedMessage; } diff --git a/core/src/main/java/io/substrait/type/YamlRead.java b/core/src/main/java/io/substrait/type/YamlRead.java index 23123bd92..1f3f57370 100644 --- a/core/src/main/java/io/substrait/type/YamlRead.java +++ b/core/src/main/java/io/substrait/type/YamlRead.java @@ -29,14 +29,14 @@ public class YamlRead { "datetime", "string")); - public static void main(String[] args) throws Exception { + public static void main(final String[] args) throws Exception { try { System.out.println( "Read: " + YamlRead.class.getResource(DefaultExtensionCatalog.FUNCTIONS_BOOLEAN)); - List signatures = loadFunctions(); + final List signatures = loadFunctions(); signatures.forEach(f -> System.out.println(f.key())); - } catch (Exception ex) { + } catch (final Exception ex) { throw ex; } } @@ -48,18 +48,18 @@ public static List loadFunctions() { .collect(java.util.stream.Collectors.toList())); } - public static List loadFunctions(List files) { + public static List loadFunctions(final List files) { return files.stream().flatMap(YamlRead::parse).collect(Collectors.toList()); } - private static Stream parse(String name) { + private static Stream parse(final String name) { try { - ObjectMapper mapper = + final ObjectMapper mapper = new ObjectMapper(new YAMLFactory()) .enable(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY) .registerModule(Deserializers.MODULE); - SimpleExtension.ExtensionSignatures doc = + final SimpleExtension.ExtensionSignatures doc = mapper.readValue(new File(name), SimpleExtension.ExtensionSignatures.class); LOGGER.atDebug().log( @@ -69,9 +69,9 @@ private static Stream parse(String name) { name); return doc.resolve(name); - } catch (RuntimeException ex) { + } catch (final RuntimeException ex) { throw ex; - } catch (Exception ex) { + } catch (final Exception ex) { throw new IllegalStateException("Failure while parsing file " + name, ex); } } diff --git a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java index ec96875bf..447bebe9b 100644 --- a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java +++ b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java @@ -21,17 +21,18 @@ public class ParseToPojo { - public static Type type(String urn, SubstraitTypeParser.StartContext ctx) { - Visitor visitor = Visitor.simple(urn); + public static Type type(final String urn, final SubstraitTypeParser.StartContext ctx) { + final Visitor visitor = Visitor.simple(urn); return (Type) ctx.accept(visitor); } public static ParameterizedType parameterizedType( - String urn, SubstraitTypeParser.StartContext ctx) { + final String urn, final SubstraitTypeParser.StartContext ctx) { return (ParameterizedType) ctx.accept(Visitor.parameterized(urn)); } - public static TypeExpression typeExpression(String urn, SubstraitTypeParser.StartContext ctx) { + public static TypeExpression typeExpression( + final String urn, final SubstraitTypeParser.StartContext ctx) { return ctx.accept(Visitor.expression(urn)); } @@ -39,19 +40,19 @@ public static class Visitor implements SubstraitTypeVisitor { private final VisitorType expressionType; private final String urn; - public static Visitor simple(String urn) { + public static Visitor simple(final String urn) { return new Visitor(VisitorType.SIMPLE, urn); } - public static Visitor parameterized(String urn) { + public static Visitor parameterized(final String urn) { return new Visitor(VisitorType.PARAMETERIZED, urn); } - public static Visitor expression(String urn) { + public static Visitor expression(final String urn) { return new Visitor(VisitorType.EXPRESSION, urn); } - private Visitor(VisitorType exprType, String urn) { + private Visitor(final VisitorType exprType, final String urn) { this.expressionType = exprType; this.urn = urn; } @@ -159,8 +160,8 @@ public Type visitIntervalYear(final SubstraitTypeParser.IntervalYearContext ctx) @Override public TypeExpression visitIntervalDay(final SubstraitTypeParser.IntervalDayContext ctx) { - boolean nullable = ctx.isnull != null; - Object precision = i(ctx.precision); + final boolean nullable = ctx.isnull != null; + final Object precision = i(ctx.precision); if (precision instanceof Integer) { return withNull(nullable).intervalDay((Integer) precision); } @@ -176,8 +177,8 @@ public TypeExpression visitIntervalDay(final SubstraitTypeParser.IntervalDayCont @Override public TypeExpression visitIntervalCompound( final SubstraitTypeParser.IntervalCompoundContext ctx) { - boolean nullable = ctx.isnull != null; - Object precision = i(ctx.precision); + final boolean nullable = ctx.isnull != null; + final Object precision = i(ctx.precision); if (precision instanceof Integer) { return withNull(nullable).intervalCompound((Integer) precision); } @@ -196,14 +197,14 @@ public Type visitUuid(final SubstraitTypeParser.UuidContext ctx) { } @Override - public Type visitUserDefined(SubstraitTypeParser.UserDefinedContext ctx) { - String name = ctx.Identifier().getSymbol().getText(); + public Type visitUserDefined(final SubstraitTypeParser.UserDefinedContext ctx) { + final String name = ctx.Identifier().getSymbol().getText(); return withNull(ctx).userDefined(urn, name); } @Override public TypeExpression visitFixedChar(final SubstraitTypeParser.FixedCharContext ctx) { - boolean nullable = ctx.isnull != null; + final boolean nullable = ctx.isnull != null; return of( ctx.len, withNull(nullable)::fixedChar, @@ -212,11 +213,11 @@ public TypeExpression visitFixedChar(final SubstraitTypeParser.FixedCharContext } private TypeExpression of( - SubstraitTypeParser.NumericParameterContext ctx, - IntFunction intFunc, - Function strFunc, - Function exprFunc) { - TypeExpression type = ctx.accept(this); + final SubstraitTypeParser.NumericParameterContext ctx, + final IntFunction intFunc, + final Function strFunc, + final Function exprFunc) { + final TypeExpression type = ctx.accept(this); if (type instanceof TypeExpression.IntegerLiteral) { return intFunc.apply(((TypeExpression.IntegerLiteral) type).value()); } @@ -230,7 +231,7 @@ private TypeExpression of( @Override public TypeExpression visitVarChar(final SubstraitTypeParser.VarCharContext ctx) { - boolean nullable = ctx.isnull != null; + final boolean nullable = ctx.isnull != null; return of( ctx.len, withNull(nullable)::varChar, @@ -240,7 +241,7 @@ public TypeExpression visitVarChar(final SubstraitTypeParser.VarCharContext ctx) @Override public TypeExpression visitFixedBinary(final SubstraitTypeParser.FixedBinaryContext ctx) { - boolean nullable = ctx.isnull != null; + final boolean nullable = ctx.isnull != null; return of( ctx.len, withNull(nullable)::fixedBinary, @@ -250,9 +251,9 @@ public TypeExpression visitFixedBinary(final SubstraitTypeParser.FixedBinaryCont @Override public TypeExpression visitDecimal(final SubstraitTypeParser.DecimalContext ctx) { - boolean nullable = ctx.isnull != null; - Object precision = i(ctx.precision); - Object scale = i(ctx.scale); + final boolean nullable = ctx.isnull != null; + final Object precision = i(ctx.precision); + final Object scale = i(ctx.scale); if (precision instanceof Integer && scale instanceof Integer) { return withNull(nullable).decimal((int) precision, (int) scale); } @@ -279,8 +280,8 @@ public TypeExpression visitDecimal(final SubstraitTypeParser.DecimalContext ctx) @Override public TypeExpression visitPrecisionTimestamp( final SubstraitTypeParser.PrecisionTimestampContext ctx) { - boolean nullable = ctx.isnull != null; - Object precision = i(ctx.precision); + final boolean nullable = ctx.isnull != null; + final Object precision = i(ctx.precision); if (precision instanceof Integer) { return withNull(nullable).precisionTimestamp((Integer) precision); } @@ -296,8 +297,8 @@ public TypeExpression visitPrecisionTimestamp( @Override public TypeExpression visitPrecisionTimestampTZ( final SubstraitTypeParser.PrecisionTimestampTZContext ctx) { - boolean nullable = ctx.isnull != null; - Object precision = i(ctx.precision); + final boolean nullable = ctx.isnull != null; + final Object precision = i(ctx.precision); if (precision instanceof Integer) { return withNull(nullable).precisionTimestampTZ((Integer) precision); } @@ -310,8 +311,8 @@ public TypeExpression visitPrecisionTimestampTZ( return withNullE(nullable).precisionTimestampTZE(ctx.precision.accept(this)); } - private Object i(SubstraitTypeParser.NumericParameterContext ctx) { - TypeExpression type = ctx.accept(this); + private Object i(final SubstraitTypeParser.NumericParameterContext ctx) { + final TypeExpression type = ctx.accept(this); if (type instanceof TypeExpression.IntegerLiteral) { return ((TypeExpression.IntegerLiteral) type).value(); } else if (type instanceof ParameterizedType.StringLiteral) { @@ -325,8 +326,8 @@ private Object i(SubstraitTypeParser.NumericParameterContext ctx) { @Override public TypeExpression visitStruct(final SubstraitTypeParser.StructContext ctx) { - boolean nullable = ctx.isnull != null; - List types = + final boolean nullable = ctx.isnull != null; + final List types = ctx.expr().stream() .map(t -> t.accept(this)) .collect(java.util.stream.Collectors.toList()); @@ -356,8 +357,8 @@ public TypeExpression visitNStruct(final SubstraitTypeParser.NStructContext ctx) @Override public TypeExpression visitList(final SubstraitTypeParser.ListContext ctx) { - boolean nullable = ctx.isnull != null; - TypeExpression element = ctx.expr().accept(this); + final boolean nullable = ctx.isnull != null; + final TypeExpression element = ctx.expr().accept(this); if (element instanceof Type) { return withNull(nullable).list((Type) element); } @@ -373,9 +374,9 @@ public TypeExpression visitList(final SubstraitTypeParser.ListContext ctx) { @Override public TypeExpression visitMap(final SubstraitTypeParser.MapContext ctx) { - boolean nullable = ctx.isnull != null; - TypeExpression key = ctx.key.accept(this); - TypeExpression value = ctx.value.accept(this); + final boolean nullable = ctx.isnull != null; + final TypeExpression key = ctx.key.accept(this); + final TypeExpression value = ctx.value.accept(this); if (key instanceof Type && value instanceof Type) { return withNull(nullable).map((Type) key, (Type) value); } @@ -388,20 +389,20 @@ public TypeExpression visitMap(final SubstraitTypeParser.MapContext ctx) { return withNullE(nullable).mapE(key, value); } - private TypeCreator withNull(SubstraitTypeParser.ScalarTypeContext required) { + private TypeCreator withNull(final SubstraitTypeParser.ScalarTypeContext required) { return Type.withNullability( ((SubstraitTypeParser.TypeContext) required.parent).isnull != null); } - private TypeCreator withNull(boolean nullable) { + private TypeCreator withNull(final boolean nullable) { return Type.withNullability(nullable); } - private TypeExpressionCreator withNullE(boolean nullable) { + private TypeExpressionCreator withNullE(final boolean nullable) { return TypeExpression.withNullability(nullable); } - private ParameterizedTypeCreator withNullP(boolean nullable) { + private ParameterizedTypeCreator withNullP(final boolean nullable) { return ParameterizedType.withNullability(nullable); } @@ -420,7 +421,7 @@ public TypeExpression visitType(final SubstraitTypeParser.TypeContext ctx) { @Override public TypeExpression visitTypeParam(final SubstraitTypeParser.TypeParamContext ctx) { checkParameterizedOrExpression(); - boolean nullable = ctx.isnull != null; + final boolean nullable = ctx.isnull != null; return ParameterizedType.StringLiteral.builder() .nullable(nullable) .value(ctx.getText()) @@ -457,17 +458,18 @@ public TypeExpression visitTernary(final SubstraitTypeParser.TernaryContext ctx) public TypeExpression visitMultilineDefinition( final SubstraitTypeParser.MultilineDefinitionContext ctx) { checkExpression(); - List exprs = + final List exprs = ctx.expr().stream() .map(t -> t.accept(this)) .collect(java.util.stream.Collectors.toList()); - List identifiers = + final List identifiers = ctx.Identifier().stream() .map(t -> t.getText()) .collect(java.util.stream.Collectors.toList()); - TypeExpression finalExpr = ctx.finalType.accept(this); + final TypeExpression finalExpr = ctx.finalType.accept(this); - ImmutableTypeExpression.ReturnProgram.Builder bldr = TypeExpression.ReturnProgram.builder(); + final ImmutableTypeExpression.ReturnProgram.Builder bldr = + TypeExpression.ReturnProgram.builder(); for (int i = 0; i < exprs.size(); i++) { bldr.addAssignments( TypeExpression.ReturnProgram.Assignment.builder() @@ -483,7 +485,7 @@ public TypeExpression visitMultilineDefinition( @Override public TypeExpression visitBinaryExpr(final SubstraitTypeParser.BinaryExprContext ctx) { checkExpression(); - TypeExpression.BinaryOperation.OpType type = getBinaryExpressionType(ctx.op); + final TypeExpression.BinaryOperation.OpType type = getBinaryExpressionType(ctx.op); return TypeExpression.BinaryOperation.builder() .opType(type) .left(ctx.left.accept(this)) @@ -491,7 +493,7 @@ public TypeExpression visitBinaryExpr(final SubstraitTypeParser.BinaryExprContex .build(); } - private TypeExpression.BinaryOperation.OpType getBinaryExpressionType(Token token) { + private TypeExpression.BinaryOperation.OpType getBinaryExpressionType(final Token token) { switch (token.getText().toUpperCase(Locale.ROOT)) { case "+": return TypeExpression.BinaryOperation.OpType.ADD; @@ -537,8 +539,8 @@ public TypeExpression visitNumericExpression( } @Override - public TypeExpression visitAnyType(SubstraitTypeParser.AnyTypeContext anyType) { - boolean nullable = ((SubstraitTypeParser.TypeContext) anyType.parent).isnull != null; + public TypeExpression visitAnyType(final SubstraitTypeParser.AnyTypeContext anyType) { + final boolean nullable = ((SubstraitTypeParser.TypeContext) anyType.parent).isnull != null; return withNullP(nullable).parameter("any"); } @@ -548,7 +550,8 @@ public TypeExpression visitFunctionCall(final SubstraitTypeParser.FunctionCallCo if (ctx.expr().size() != 2) { throw new IllegalStateException("Only two argument functions exist for type expressions."); } - TypeExpression.BinaryOperation.OpType type = getFunctionType(ctx.Identifier().getSymbol()); + final TypeExpression.BinaryOperation.OpType type = + getFunctionType(ctx.Identifier().getSymbol()); return TypeExpression.BinaryOperation.builder() .opType(type) .left(ctx.expr(0).accept(this)) @@ -556,7 +559,7 @@ public TypeExpression visitFunctionCall(final SubstraitTypeParser.FunctionCallCo .build(); } - private TypeExpression.BinaryOperation.OpType getFunctionType(Token token) { + private TypeExpression.BinaryOperation.OpType getFunctionType(final Token token) { switch (token.getText().toUpperCase(Locale.ROOT)) { case "MIN": return TypeExpression.BinaryOperation.OpType.MIN; @@ -578,7 +581,7 @@ public TypeExpression visitLiteralNumber(final SubstraitTypeParser.LiteralNumber return i(Integer.parseInt(ctx.getText())); } - protected TypeExpression i(int val) { + protected TypeExpression i(final int val) { return TypeExpression.IntegerLiteral.builder().value(val).build(); } diff --git a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java index b8e94793b..9f55d6d8f 100644 --- a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java +++ b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java @@ -16,35 +16,37 @@ public class TypeStringParser { private TypeStringParser() {} - public static Type parseSimple(String str, String urn) { + public static Type parseSimple(final String str, final String urn) { return parse(str, urn, ParseToPojo::type); } - public static ParameterizedType parseParameterized(String str, String urn) { + public static ParameterizedType parseParameterized(final String str, final String urn) { return parse(str, urn, ParseToPojo::parameterizedType); } - public static TypeExpression parseExpression(String str, String urn) { + public static TypeExpression parseExpression(final String str, final String urn) { return parse(str, urn, ParseToPojo::typeExpression); } - private static SubstraitTypeParser.StartContext parse(String str) { - SubstraitTypeLexer lexer = new SubstraitTypeLexer(CharStreams.fromString(str)); + private static SubstraitTypeParser.StartContext parse(final String str) { + final SubstraitTypeLexer lexer = new SubstraitTypeLexer(CharStreams.fromString(str)); lexer.removeErrorListeners(); lexer.addErrorListener(TypeErrorListener.INSTANCE); - CommonTokenStream tokenStream = new CommonTokenStream(lexer); - SubstraitTypeParser parser = new io.substrait.type.SubstraitTypeParser(tokenStream); + final CommonTokenStream tokenStream = new CommonTokenStream(lexer); + final SubstraitTypeParser parser = new io.substrait.type.SubstraitTypeParser(tokenStream); parser.removeErrorListeners(); parser.addErrorListener(TypeErrorListener.INSTANCE); return parser.start(); } public static T parse( - String str, String urn, BiFunction func) { + final String str, + final String urn, + final BiFunction func) { return func.apply(urn, parse(str)); } - public static TypeExpression parse(String str, ParseToPojo.Visitor visitor) { + public static TypeExpression parse(final String str, final ParseToPojo.Visitor visitor) { return parse(str).accept(visitor); } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 691d4bce5..73ced2f80 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -13,12 +13,13 @@ abstract class BaseProtoConverter public abstract BaseProtoTypes typeContainer(boolean nullable); - public BaseProtoConverter(ExtensionCollector extensionCollector, String unsupportedMessage) { + public BaseProtoConverter( + final ExtensionCollector extensionCollector, final String unsupportedMessage) { super(unsupportedMessage); this.extensionCollector = extensionCollector; } - public final BaseProtoTypes typeContainer(NullableType literal) { + public final BaseProtoTypes typeContainer(final NullableType literal) { return typeContainer(literal.nullable()); } @@ -163,7 +164,7 @@ public final T visit(final Type.Map expr) { @Override public final T visit(final Type.UserDefined expr) { - int ref = + final int ref = extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); return typeContainer(expr).userDefined(ref); } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 6a1bc3186..598d09d35 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -22,7 +22,7 @@ abstract class BaseProtoTypes { public final T INTERVAL_YEAR; public final T UUID; - public BaseProtoTypes(Type.Nullability nullability) { + public BaseProtoTypes(final Type.Nullability nullability) { this.nullability = nullability; BOOLEAN = wrap(Type.Boolean.newBuilder().setNullability(nullability).build()); I8 = wrap(Type.I8.newBuilder().setNullability(nullability).build()); @@ -43,59 +43,59 @@ public BaseProtoTypes(Type.Nullability nullability) { public abstract T fixedChar(I len); - public final T fixedChar(int len) { + public final T fixedChar(final int len) { return fixedChar(i(len)); } - public final T fixedChar(String len) { + public final T fixedChar(final String len) { return fixedChar(integerParam(len)); } - public final T varChar(int len) { + public final T varChar(final int len) { return varChar(i(len)); } - public final T varChar(String len) { + public final T varChar(final String len) { return varChar(integerParam(len)); } - public final T fixedBinary(int len) { + public final T fixedBinary(final int len) { return fixedBinary(i(len)); } - public final T fixedBinary(String len) { + public final T fixedBinary(final String len) { return fixedBinary(integerParam(len)); } - public final T decimal(int scale, int precision) { + public final T decimal(final int scale, final int precision) { return decimal(i(scale), i(precision)); } - public final T decimal(I scale, int precision) { + public final T decimal(final I scale, final int precision) { return decimal(scale, i(precision)); } - public final T decimal(int scale, I precision) { + public final T decimal(final int scale, final I precision) { return decimal(i(scale), precision); } - public final T intervalDay(int precision) { + public final T intervalDay(final int precision) { return intervalDay(i(precision)); } - public final T intervalCompound(int precision) { + public final T intervalCompound(final int precision) { return intervalCompound(i(precision)); } - public final T precisionTime(int precision) { + public final T precisionTime(final int precision) { return precisionTime(i(precision)); } - public final T precisionTimestamp(int precision) { + public final T precisionTimestamp(final int precision) { return precisionTimestamp(i(precision)); } - public final T precisionTimestampTZ(int precision) { + public final T precisionTimestampTZ(final int precision) { return precisionTimestampTZ(i(precision)); } @@ -119,7 +119,7 @@ public final T precisionTimestampTZ(int precision) { public abstract T intervalCompound(I precision); - public final T struct(T... types) { + public final T struct(final T... types) { return struct(Arrays.asList(types)); } diff --git a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java index 4e0caa7c2..017e8b830 100644 --- a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java @@ -14,7 +14,7 @@ public class ParameterizedProtoConverter private static final BaseProtoTypes PARAMETERIZED_REQUIRED = new ParameterizedTypes(Type.Nullability.NULLABILITY_REQUIRED); - public ParameterizedProtoConverter(ExtensionCollector extensionCollector) { + public ParameterizedProtoConverter(final ExtensionCollector extensionCollector) { super(extensionCollector, "Parameterized types cannot include return type expressions."); } @@ -29,43 +29,45 @@ public ParameterizedType.IntegerOption i(final TypeExpression num) { } @Override - public ParameterizedType visit(io.substrait.function.ParameterizedType.FixedChar expr) + public ParameterizedType visit(final io.substrait.function.ParameterizedType.FixedChar expr) throws RuntimeException { return typeContainer(expr).fixedChar(expr.length().value()); } @Override - public ParameterizedType visit(io.substrait.function.ParameterizedType.VarChar expr) + public ParameterizedType visit(final io.substrait.function.ParameterizedType.VarChar expr) throws RuntimeException { return typeContainer(expr).varChar(expr.length().value()); } @Override - public ParameterizedType visit(io.substrait.function.ParameterizedType.FixedBinary expr) + public ParameterizedType visit(final io.substrait.function.ParameterizedType.FixedBinary expr) throws RuntimeException { return typeContainer(expr).fixedBinary(expr.length().value()); } @Override - public ParameterizedType visit(io.substrait.function.ParameterizedType.Decimal expr) + public ParameterizedType visit(final io.substrait.function.ParameterizedType.Decimal expr) throws RuntimeException { return typeContainer(expr).decimal(i(expr.precision()), i(expr.scale())); } @Override - public ParameterizedType visit(io.substrait.function.ParameterizedType.PrecisionTimestamp expr) + public ParameterizedType visit( + final io.substrait.function.ParameterizedType.PrecisionTimestamp expr) throws RuntimeException { return typeContainer(expr).precisionTimestamp(i(expr.precision())); } @Override - public ParameterizedType visit(io.substrait.function.ParameterizedType.PrecisionTimestampTZ expr) + public ParameterizedType visit( + final io.substrait.function.ParameterizedType.PrecisionTimestampTZ expr) throws RuntimeException { return typeContainer(expr).precisionTimestampTZ(i(expr.precision())); } @Override - public ParameterizedType visit(io.substrait.function.ParameterizedType.Struct expr) + public ParameterizedType visit(final io.substrait.function.ParameterizedType.Struct expr) throws RuntimeException { return typeContainer(expr) .struct( @@ -75,20 +77,21 @@ public ParameterizedType visit(io.substrait.function.ParameterizedType.Struct ex } @Override - public ParameterizedType visit(io.substrait.function.ParameterizedType.ListType expr) + public ParameterizedType visit(final io.substrait.function.ParameterizedType.ListType expr) throws RuntimeException { return typeContainer(expr).list(expr.name().accept(this)); } @Override - public ParameterizedType visit(io.substrait.function.ParameterizedType.Map expr) + public ParameterizedType visit(final io.substrait.function.ParameterizedType.Map expr) throws RuntimeException { return typeContainer(expr).map(expr.key().accept(this), expr.value().accept(this)); } @Override public ParameterizedType visit( - io.substrait.function.ParameterizedType.StringLiteral stringLiteral) throws RuntimeException { + final io.substrait.function.ParameterizedType.StringLiteral stringLiteral) + throws RuntimeException { return ParameterizedType.newBuilder() .setTypeParameter( ParameterizedType.TypeParameter.newBuilder().setName(stringLiteral.value())) @@ -110,7 +113,7 @@ public ParameterizedType.IntegerOption visit(final TypeExpression.IntegerLiteral @Override public ParameterizedType.IntegerOption visit( - io.substrait.function.ParameterizedType.StringLiteral stringLiteral) + final io.substrait.function.ParameterizedType.StringLiteral stringLiteral) throws RuntimeException { return ParameterizedType.IntegerOption.newBuilder() .setParameter( @@ -127,7 +130,7 @@ public ParameterizedTypes(final Type.Nullability nullability) { } @Override - public ParameterizedType fixedChar(ParameterizedType.IntegerOption len) { + public ParameterizedType fixedChar(final ParameterizedType.IntegerOption len) { return wrap( ParameterizedType.ParameterizedFixedChar.newBuilder() .setLength(len) @@ -150,12 +153,12 @@ public ParameterizedType.IntegerOption integerParam(final String name) { } @Override - protected ParameterizedType.IntegerOption i(int len) { + protected ParameterizedType.IntegerOption i(final int len) { return ParameterizedType.IntegerOption.newBuilder().setLiteral(len).build(); } @Override - public ParameterizedType varChar(ParameterizedType.IntegerOption len) { + public ParameterizedType varChar(final ParameterizedType.IntegerOption len) { return wrap( ParameterizedType.ParameterizedVarChar.newBuilder() .setLength(len) @@ -164,7 +167,7 @@ public ParameterizedType varChar(ParameterizedType.IntegerOption len) { } @Override - public ParameterizedType fixedBinary(ParameterizedType.IntegerOption len) { + public ParameterizedType fixedBinary(final ParameterizedType.IntegerOption len) { return wrap( ParameterizedType.ParameterizedFixedBinary.newBuilder() .setLength(len) @@ -174,7 +177,8 @@ public ParameterizedType fixedBinary(ParameterizedType.IntegerOption len) { @Override public ParameterizedType decimal( - ParameterizedType.IntegerOption scale, ParameterizedType.IntegerOption precision) { + final ParameterizedType.IntegerOption scale, + final ParameterizedType.IntegerOption precision) { return wrap( ParameterizedType.ParameterizedDecimal.newBuilder() .setScale(scale) @@ -184,7 +188,7 @@ public ParameterizedType decimal( } @Override - public ParameterizedType intervalDay(ParameterizedType.IntegerOption precision) { + public ParameterizedType intervalDay(final ParameterizedType.IntegerOption precision) { return wrap( ParameterizedType.ParameterizedIntervalDay.newBuilder() .setPrecision(precision) @@ -193,7 +197,7 @@ public ParameterizedType intervalDay(ParameterizedType.IntegerOption precision) } @Override - public ParameterizedType intervalCompound(ParameterizedType.IntegerOption precision) { + public ParameterizedType intervalCompound(final ParameterizedType.IntegerOption precision) { return wrap( ParameterizedType.ParameterizedIntervalCompound.newBuilder() .setPrecision(precision) @@ -202,7 +206,7 @@ public ParameterizedType intervalCompound(ParameterizedType.IntegerOption precis } @Override - public ParameterizedType precisionTime(ParameterizedType.IntegerOption precision) { + public ParameterizedType precisionTime(final ParameterizedType.IntegerOption precision) { return wrap( ParameterizedType.ParameterizedPrecisionTime.newBuilder() .setPrecision(precision) @@ -211,7 +215,7 @@ public ParameterizedType precisionTime(ParameterizedType.IntegerOption precision } @Override - public ParameterizedType precisionTimestamp(ParameterizedType.IntegerOption precision) { + public ParameterizedType precisionTimestamp(final ParameterizedType.IntegerOption precision) { return wrap( ParameterizedType.ParameterizedPrecisionTimestamp.newBuilder() .setPrecision(precision) @@ -220,7 +224,7 @@ public ParameterizedType precisionTimestamp(ParameterizedType.IntegerOption prec } @Override - public ParameterizedType precisionTimestampTZ(ParameterizedType.IntegerOption precision) { + public ParameterizedType precisionTimestampTZ(final ParameterizedType.IntegerOption precision) { return wrap( ParameterizedType.ParameterizedPrecisionTimestampTZ.newBuilder() .setPrecision(precision) @@ -229,7 +233,7 @@ public ParameterizedType precisionTimestampTZ(ParameterizedType.IntegerOption pr } @Override - public ParameterizedType struct(Iterable types) { + public ParameterizedType struct(final Iterable types) { return wrap( ParameterizedType.ParameterizedStruct.newBuilder() .addAllTypes(types) @@ -238,7 +242,7 @@ public ParameterizedType struct(Iterable types) { } @Override - public ParameterizedType list(ParameterizedType type) { + public ParameterizedType list(final ParameterizedType type) { return wrap( ParameterizedType.ParameterizedList.newBuilder() .setType(type) @@ -247,7 +251,7 @@ public ParameterizedType list(ParameterizedType type) { } @Override - public ParameterizedType map(ParameterizedType key, ParameterizedType value) { + public ParameterizedType map(final ParameterizedType key, final ParameterizedType value) { return wrap( ParameterizedType.ParameterizedMap.newBuilder() .setKey(key) @@ -257,14 +261,14 @@ public ParameterizedType map(ParameterizedType key, ParameterizedType value) { } @Override - public ParameterizedType userDefined(int ref) { + public ParameterizedType userDefined(final int ref) { throw new UnsupportedOperationException( "User defined types are not supported in Parameterized Types for now"); } @Override protected ParameterizedType wrap(final Object o) { - ParameterizedType.Builder bldr = ParameterizedType.newBuilder(); + final ParameterizedType.Builder bldr = ParameterizedType.newBuilder(); if (o instanceof Type.Boolean) { return bldr.setBool((Type.Boolean) o).build(); } else if (o instanceof Type.I8) { diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index 95d42328a..ecbb06dfa 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -12,12 +12,12 @@ public class ProtoTypeConverter { private final SimpleExtension.ExtensionCollection extensions; public ProtoTypeConverter( - ExtensionLookup lookup, SimpleExtension.ExtensionCollection extensions) { + final ExtensionLookup lookup, final SimpleExtension.ExtensionCollection extensions) { this.lookup = lookup; this.extensions = extensions; } - public Type from(io.substrait.proto.Type type) { + public Type from(final io.substrait.proto.Type type) { switch (type.getKindCase()) { case BOOL: return n(type.getBool().getNullability()).BOOLEAN; @@ -88,8 +88,8 @@ public Type from(io.substrait.proto.Type type) { return fromMap(type.getMap()); case USER_DEFINED: { - io.substrait.proto.Type.UserDefined userDefined = type.getUserDefined(); - SimpleExtension.Type t = lookup.getType(userDefined.getTypeReference(), extensions); + final io.substrait.proto.Type.UserDefined userDefined = type.getUserDefined(); + final SimpleExtension.Type t = lookup.getType(userDefined.getTypeReference(), extensions); return n(userDefined.getNullability()).userDefined(t.urn(), t.name()); } case USER_DEFINED_TYPE_REFERENCE: @@ -101,19 +101,19 @@ public Type from(io.substrait.proto.Type type) { } } - public Type.ListType fromList(io.substrait.proto.Type.List list) { + public Type.ListType fromList(final io.substrait.proto.Type.List list) { return n(list.getNullability()).list(from(list.getType())); } - public Type.Map fromMap(io.substrait.proto.Type.Map map) { + public Type.Map fromMap(final io.substrait.proto.Type.Map map) { return n(map.getNullability()).map(from(map.getKey()), from(map.getValue())); } - public static boolean isNullable(io.substrait.proto.Type.Nullability nullability) { + public static boolean isNullable(final io.substrait.proto.Type.Nullability nullability) { return io.substrait.proto.Type.Nullability.NULLABILITY_NULLABLE == nullability; } - private static TypeCreator n(io.substrait.proto.Type.Nullability n) { + private static TypeCreator n(final io.substrait.proto.Type.Nullability n) { return n == io.substrait.proto.Type.Nullability.NULLABILITY_NULLABLE ? TypeCreator.NULLABLE : TypeCreator.REQUIRED; diff --git a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java index 96cddd395..8a71f8bcd 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java +++ b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java @@ -16,7 +16,7 @@ public class TypeExpressionProtoVisitor private static final DerivationTypes DERIVATION_REQUIRED = new DerivationTypes(Type.Nullability.NULLABILITY_REQUIRED); - public TypeExpressionProtoVisitor(ExtensionCollector extensionCollector) { + public TypeExpressionProtoVisitor(final ExtensionCollector extensionCollector) { super(extensionCollector, "Unexpected expression type. This shouldn't happen."); } @@ -28,7 +28,7 @@ public BaseProtoTypes typeContainer( @Override public DerivationExpression visit(final TypeExpression.BinaryOperation expr) { - DerivationExpression.BinaryOp.BinaryOpType opType = getDerivationOpType(expr.opType()); + final DerivationExpression.BinaryOp.BinaryOpType opType = getDerivationOpType(expr.opType()); return DerivationExpression.newBuilder() .setBinaryOp( DerivationExpression.BinaryOp.newBuilder() @@ -40,7 +40,7 @@ public DerivationExpression visit(final TypeExpression.BinaryOperation expr) { } private DerivationExpression.BinaryOp.BinaryOpType getDerivationOpType( - TypeExpression.BinaryOperation.OpType type) { + final TypeExpression.BinaryOperation.OpType type) { switch (type) { case ADD: return DerivationExpression.BinaryOp.BinaryOpType.BINARY_OP_TYPE_PLUS; @@ -95,7 +95,7 @@ public DerivationExpression visit(final TypeExpression.IntegerLiteral expr) { @Override public DerivationExpression visit(final TypeExpression.ReturnProgram expr) { - List assignments = + final List assignments = expr.assignments().stream() .map( a -> @@ -104,7 +104,7 @@ public DerivationExpression visit(final TypeExpression.ReturnProgram expr) { .setExpression(a.expr().accept(this)) .build()) .collect(java.util.stream.Collectors.toList()); - DerivationExpression finalExpr = expr.finalExpression().accept(this); + final DerivationExpression finalExpr = expr.finalExpression().accept(this); return DerivationExpression.newBuilder() .setReturnProgram( DerivationExpression.ReturnProgram.newBuilder() @@ -115,47 +115,47 @@ public DerivationExpression visit(final TypeExpression.ReturnProgram expr) { } @Override - public DerivationExpression visit(ParameterizedType.FixedChar expr) { + public DerivationExpression visit(final ParameterizedType.FixedChar expr) { return typeContainer(expr).fixedChar(expr.length().value()); } @Override - public DerivationExpression visit(ParameterizedType.VarChar expr) { + public DerivationExpression visit(final ParameterizedType.VarChar expr) { return typeContainer(expr).varChar(expr.length().value()); } @Override - public DerivationExpression visit(ParameterizedType.FixedBinary expr) { + public DerivationExpression visit(final ParameterizedType.FixedBinary expr) { return typeContainer(expr).fixedBinary(expr.length().value()); } @Override - public DerivationExpression visit(ParameterizedType.Decimal expr) { + public DerivationExpression visit(final ParameterizedType.Decimal expr) { return typeContainer(expr).decimal(expr.precision().accept(this), expr.scale().accept(this)); } @Override - public DerivationExpression visit(ParameterizedType.IntervalDay expr) { + public DerivationExpression visit(final ParameterizedType.IntervalDay expr) { return typeContainer(expr).intervalDay(expr.precision().accept(this)); } @Override - public DerivationExpression visit(ParameterizedType.IntervalCompound expr) { + public DerivationExpression visit(final ParameterizedType.IntervalCompound expr) { return typeContainer(expr).intervalCompound(expr.precision().accept(this)); } @Override - public DerivationExpression visit(ParameterizedType.PrecisionTimestamp expr) { + public DerivationExpression visit(final ParameterizedType.PrecisionTimestamp expr) { return typeContainer(expr).precisionTimestamp(expr.precision().accept(this)); } @Override - public DerivationExpression visit(TypeExpression.PrecisionTimestampTZ expr) { + public DerivationExpression visit(final TypeExpression.PrecisionTimestampTZ expr) { return typeContainer(expr).precisionTimestampTZ(expr.precision().accept(this)); } @Override - public DerivationExpression visit(ParameterizedType.Struct expr) { + public DerivationExpression visit(final ParameterizedType.Struct expr) { return typeContainer(expr) .struct( expr.fields().stream() @@ -164,42 +164,42 @@ public DerivationExpression visit(ParameterizedType.Struct expr) { } @Override - public DerivationExpression visit(ParameterizedType.ListType expr) { + public DerivationExpression visit(final ParameterizedType.ListType expr) { return typeContainer(expr).list(expr.name().accept(this)); } @Override - public DerivationExpression visit(ParameterizedType.Map expr) { + public DerivationExpression visit(final ParameterizedType.Map expr) { return typeContainer(expr).map(expr.key().accept(this), expr.value().accept(this)); } @Override - public DerivationExpression visit(ParameterizedType.StringLiteral stringLiteral) { + public DerivationExpression visit(final ParameterizedType.StringLiteral stringLiteral) { return DerivationExpression.newBuilder().setTypeParameterName(stringLiteral.value()).build(); } @Override - public DerivationExpression visit(TypeExpression.FixedChar expr) { + public DerivationExpression visit(final TypeExpression.FixedChar expr) { return typeContainer(expr).fixedChar(expr.length().accept(this)); } @Override - public DerivationExpression visit(TypeExpression.VarChar expr) { + public DerivationExpression visit(final TypeExpression.VarChar expr) { return typeContainer(expr).varChar(expr.length().accept(this)); } @Override - public DerivationExpression visit(TypeExpression.FixedBinary expr) { + public DerivationExpression visit(final TypeExpression.FixedBinary expr) { return typeContainer(expr).fixedBinary(expr.length().accept(this)); } @Override - public DerivationExpression visit(TypeExpression.Decimal expr) { + public DerivationExpression visit(final TypeExpression.Decimal expr) { return typeContainer(expr).decimal(expr.precision().accept(this), expr.scale().accept(this)); } @Override - public DerivationExpression visit(TypeExpression.Struct expr) { + public DerivationExpression visit(final TypeExpression.Struct expr) { return typeContainer(expr) .struct( expr.fields().stream() @@ -208,12 +208,12 @@ public DerivationExpression visit(TypeExpression.Struct expr) { } @Override - public DerivationExpression visit(TypeExpression.ListType expr) { + public DerivationExpression visit(final TypeExpression.ListType expr) { return typeContainer(expr).list(expr.elementType().accept(this)); } @Override - public DerivationExpression visit(TypeExpression.Map expr) { + public DerivationExpression visit(final TypeExpression.Map expr) { return typeContainer(expr).map(expr.key().accept(this), expr.value().accept(this)); } @@ -225,7 +225,7 @@ public DerivationTypes(final Type.Nullability nullability) { } @Override - public DerivationExpression fixedChar(DerivationExpression len) { + public DerivationExpression fixedChar(final DerivationExpression len) { return wrap( DerivationExpression.ExpressionFixedChar.newBuilder() .setLength(len) @@ -244,7 +244,7 @@ public DerivationExpression integerParam(final String name) { } @Override - public DerivationExpression varChar(DerivationExpression len) { + public DerivationExpression varChar(final DerivationExpression len) { return wrap( DerivationExpression.ExpressionVarChar.newBuilder() .setLength(len) @@ -253,7 +253,7 @@ public DerivationExpression varChar(DerivationExpression len) { } @Override - public DerivationExpression fixedBinary(DerivationExpression len) { + public DerivationExpression fixedBinary(final DerivationExpression len) { return wrap( DerivationExpression.ExpressionFixedBinary.newBuilder() .setLength(len) @@ -263,7 +263,7 @@ public DerivationExpression fixedBinary(DerivationExpression len) { @Override public DerivationExpression decimal( - DerivationExpression scale, DerivationExpression precision) { + final DerivationExpression scale, final DerivationExpression precision) { return wrap( DerivationExpression.ExpressionDecimal.newBuilder() .setScale(scale) @@ -273,7 +273,7 @@ public DerivationExpression decimal( } @Override - public DerivationExpression precisionTime(DerivationExpression precision) { + public DerivationExpression precisionTime(final DerivationExpression precision) { return wrap( DerivationExpression.ExpressionPrecisionTime.newBuilder() .setPrecision(precision) @@ -282,7 +282,7 @@ public DerivationExpression precisionTime(DerivationExpression precision) { } @Override - public DerivationExpression precisionTimestamp(DerivationExpression precision) { + public DerivationExpression precisionTimestamp(final DerivationExpression precision) { return wrap( DerivationExpression.ExpressionPrecisionTimestamp.newBuilder() .setPrecision(precision) @@ -291,7 +291,7 @@ public DerivationExpression precisionTimestamp(DerivationExpression precision) { } @Override - public DerivationExpression precisionTimestampTZ(DerivationExpression precision) { + public DerivationExpression precisionTimestampTZ(final DerivationExpression precision) { return wrap( DerivationExpression.ExpressionPrecisionTimestampTZ.newBuilder() .setPrecision(precision) @@ -300,7 +300,7 @@ public DerivationExpression precisionTimestampTZ(DerivationExpression precision) } @Override - public DerivationExpression intervalDay(DerivationExpression precision) { + public DerivationExpression intervalDay(final DerivationExpression precision) { return wrap( DerivationExpression.ExpressionIntervalDay.newBuilder() .setPrecision(precision) @@ -309,7 +309,7 @@ public DerivationExpression intervalDay(DerivationExpression precision) { } @Override - public DerivationExpression intervalCompound(DerivationExpression precision) { + public DerivationExpression intervalCompound(final DerivationExpression precision) { return wrap( DerivationExpression.ExpressionIntervalCompound.newBuilder() .setPrecision(precision) @@ -318,7 +318,7 @@ public DerivationExpression intervalCompound(DerivationExpression precision) { } @Override - public DerivationExpression struct(Iterable types) { + public DerivationExpression struct(final Iterable types) { return wrap( DerivationExpression.ExpressionStruct.newBuilder() .addAllTypes(types) @@ -326,12 +326,12 @@ public DerivationExpression struct(Iterable types) { .build()); } - public DerivationExpression param(String name) { + public DerivationExpression param(final String name) { return DerivationExpression.newBuilder().setTypeParameterName(name).build(); } @Override - public DerivationExpression list(DerivationExpression type) { + public DerivationExpression list(final DerivationExpression type) { return wrap( DerivationExpression.ExpressionList.newBuilder() .setType(type) @@ -340,7 +340,8 @@ public DerivationExpression list(DerivationExpression type) { } @Override - public DerivationExpression map(DerivationExpression key, DerivationExpression value) { + public DerivationExpression map( + final DerivationExpression key, final DerivationExpression value) { return wrap( DerivationExpression.ExpressionMap.newBuilder() .setKey(key) @@ -350,14 +351,14 @@ public DerivationExpression map(DerivationExpression key, DerivationExpression v } @Override - public DerivationExpression userDefined(int ref) { + public DerivationExpression userDefined(final int ref) { throw new UnsupportedOperationException( "User defined types are not supported in Derivation Expressions for now"); } @Override protected DerivationExpression wrap(final Object o) { - DerivationExpression.Builder bldr = DerivationExpression.newBuilder(); + final DerivationExpression.Builder bldr = DerivationExpression.newBuilder(); if (o instanceof Type.Boolean) { return bldr.setBool((Type.Boolean) o).build(); } else if (o instanceof Type.I8) { diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 2d0ed0ffc..1af97855a 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -10,11 +10,11 @@ public class TypeProtoConverter extends BaseProtoConverter { private static final BaseProtoTypes REQUIRED = new Types(Type.Nullability.NULLABILITY_REQUIRED); - public TypeProtoConverter(ExtensionCollector extensionCollector) { + public TypeProtoConverter(final ExtensionCollector extensionCollector) { super(extensionCollector, "Type literals cannot contain parameters or expressions."); } - public io.substrait.proto.Type toProto(io.substrait.type.Type type) { + public io.substrait.proto.Type toProto(final io.substrait.type.Type type) { return type.accept(this); } @@ -30,7 +30,7 @@ public Types(final Type.Nullability nullability) { } @Override - public Type fixedChar(Integer len) { + public Type fixedChar(final Integer len) { return wrap(Type.FixedChar.newBuilder().setLength(len).setNullability(nullability).build()); } @@ -47,17 +47,17 @@ public Integer integerParam(final String name) { } @Override - public Type varChar(Integer len) { + public Type varChar(final Integer len) { return wrap(Type.VarChar.newBuilder().setLength(len).setNullability(nullability).build()); } @Override - public Type fixedBinary(Integer len) { + public Type fixedBinary(final Integer len) { return wrap(Type.FixedBinary.newBuilder().setLength(len).setNullability(nullability).build()); } @Override - public Type decimal(Integer scale, Integer precision) { + public Type decimal(final Integer scale, final Integer precision) { return wrap( Type.Decimal.newBuilder() .setScale(scale) @@ -67,7 +67,7 @@ public Type decimal(Integer scale, Integer precision) { } @Override - public Type intervalDay(Integer precision) { + public Type intervalDay(final Integer precision) { return wrap( Type.IntervalDay.newBuilder() .setPrecision(precision) @@ -76,7 +76,7 @@ public Type intervalDay(Integer precision) { } @Override - public Type intervalCompound(Integer precision) { + public Type intervalCompound(final Integer precision) { return wrap( Type.IntervalCompound.newBuilder() .setPrecision(precision) @@ -85,7 +85,7 @@ public Type intervalCompound(Integer precision) { } @Override - public Type precisionTime(Integer precision) { + public Type precisionTime(final Integer precision) { return wrap( Type.PrecisionTime.newBuilder() .setPrecision(precision) @@ -94,7 +94,7 @@ public Type precisionTime(Integer precision) { } @Override - public Type precisionTimestamp(Integer precision) { + public Type precisionTimestamp(final Integer precision) { return wrap( Type.PrecisionTimestamp.newBuilder() .setPrecision(precision) @@ -103,7 +103,7 @@ public Type precisionTimestamp(Integer precision) { } @Override - public Type precisionTimestampTZ(Integer precision) { + public Type precisionTimestampTZ(final Integer precision) { return wrap( Type.PrecisionTimestampTZ.newBuilder() .setPrecision(precision) @@ -112,30 +112,30 @@ public Type precisionTimestampTZ(Integer precision) { } @Override - public Type struct(Iterable types) { + public Type struct(final Iterable types) { return wrap(Type.Struct.newBuilder().addAllTypes(types).setNullability(nullability).build()); } @Override - public Type list(Type type) { + public Type list(final Type type) { return wrap(Type.List.newBuilder().setType(type).setNullability(nullability).build()); } @Override - public Type map(Type key, Type value) { + public Type map(final Type key, final Type value) { return wrap( Type.Map.newBuilder().setKey(key).setValue(value).setNullability(nullability).build()); } @Override - public Type userDefined(int ref) { + public Type userDefined(final int ref) { return wrap( Type.UserDefined.newBuilder().setTypeReference(ref).setNullability(nullability).build()); } @Override protected Type wrap(final Object o) { - Type.Builder bldr = Type.newBuilder(); + final Type.Builder bldr = Type.newBuilder(); if (o instanceof Type.Boolean) { return bldr.setBool((Type.Boolean) o).build(); } else if (o instanceof Type.I8) { diff --git a/core/src/main/java/io/substrait/util/DecimalUtil.java b/core/src/main/java/io/substrait/util/DecimalUtil.java index 18886059d..4b0b84de4 100644 --- a/core/src/main/java/io/substrait/util/DecimalUtil.java +++ b/core/src/main/java/io/substrait/util/DecimalUtil.java @@ -41,12 +41,13 @@ public class DecimalUtil { * @param byteWidth * @return */ - public static BigDecimal getBigDecimalFromBytes(byte[] value, int scale, int byteWidth) { - byte[] reversed = new byte[value.length]; + public static BigDecimal getBigDecimalFromBytes( + final byte[] value, final int scale, final int byteWidth) { + final byte[] reversed = new byte[value.length]; for (int i = 0; i < byteWidth; i++) { reversed[byteWidth - 1 - i] = value[i]; } - BigInteger unscaledValue = new BigInteger(reversed); + final BigInteger unscaledValue = new BigInteger(reversed); return new BigDecimal(unscaledValue, scale); } @@ -59,17 +60,18 @@ public static BigDecimal getBigDecimalFromBytes(byte[] value, int scale, int byt * @param byteWidth * @return */ - public static byte[] encodeDecimalIntoBytes(BigDecimal decimal, int scale, int byteWidth) { - BigDecimal scaledDecimal = decimal.multiply(powerOfTen(scale)); - byte[] bytes = scaledDecimal.toBigInteger().toByteArray(); + public static byte[] encodeDecimalIntoBytes( + final BigDecimal decimal, final int scale, final int byteWidth) { + final BigDecimal scaledDecimal = decimal.multiply(powerOfTen(scale)); + final byte[] bytes = scaledDecimal.toBigInteger().toByteArray(); if (bytes.length > byteWidth) { throw new UnsupportedOperationException( "Decimal size greater than " + byteWidth + " bytes: " + bytes.length); } - byte[] encodedBytes = new byte[byteWidth]; - byte padByte = bytes[0] < 0 ? minus_one : zero; + final byte[] encodedBytes = new byte[byteWidth]; + final byte padByte = bytes[0] < 0 ? minus_one : zero; // Decimal stored as native-endian, need to swap data bytes if LE - byte[] bytesLE = new byte[bytes.length]; + final byte[] bytesLE = new byte[bytes.length]; for (int i = 0; i < bytes.length; i++) { bytesLE[i] = bytes[bytes.length - 1 - i]; } @@ -85,11 +87,11 @@ public static byte[] encodeDecimalIntoBytes(BigDecimal decimal, int scale, int b return encodedBytes; } - private static BigDecimal powerOfTen(int scale) { + private static BigDecimal powerOfTen(final int scale) { if (scale < POWER_OF_10.length) { return new BigDecimal(POWER_OF_10[scale]); } else { - int length = POWER_OF_10.length; + final int length = POWER_OF_10.length; BigDecimal bd = new BigDecimal(POWER_OF_10[length - 1]); for (int i = length - 1; i < scale; i++) { diff --git a/core/src/main/java/io/substrait/util/Util.java b/core/src/main/java/io/substrait/util/Util.java index 1b2ecb2ac..02b4eca5d 100644 --- a/core/src/main/java/io/substrait/util/Util.java +++ b/core/src/main/java/io/substrait/util/Util.java @@ -4,7 +4,7 @@ public class Util { - public static Supplier memoize(Supplier supplier) { + public static Supplier memoize(final Supplier supplier) { return new Memoizer(supplier); } @@ -12,9 +12,9 @@ private static class Memoizer implements Supplier { private boolean retrieved; private T value; - private Supplier delegate; + private final Supplier delegate; - public Memoizer(Supplier delegate) { + public Memoizer(final Supplier delegate) { this.delegate = delegate; } @@ -32,11 +32,11 @@ public static class IntRange { private final int startInclusive; private final int endExclusive; - public static IntRange of(int startInclusive, int endExclusive) { + public static IntRange of(final int startInclusive, final int endExclusive) { return new IntRange(startInclusive, endExclusive); } - private IntRange(int startInclusive, int endExclusive) { + private IntRange(final int startInclusive, final int endExclusive) { this.startInclusive = startInclusive; this.endExclusive = endExclusive; } @@ -49,7 +49,7 @@ public int getEndExclusive() { return endExclusive; } - public boolean within(int val) { + public boolean within(final int val) { return val >= startInclusive && val < endExclusive; } } diff --git a/core/src/test/java/io/substrait/TestBase.java b/core/src/test/java/io/substrait/TestBase.java index 3defbf78f..dfc73994f 100644 --- a/core/src/test/java/io/substrait/TestBase.java +++ b/core/src/test/java/io/substrait/TestBase.java @@ -25,9 +25,9 @@ public abstract class TestBase { protected ProtoRelConverter protoRelConverter = new ProtoRelConverter(functionCollector, defaultExtensionCollection); - protected void verifyRoundTrip(Rel rel) { - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); - Rel relReturned = protoRelConverter.from(protoRel); + protected void verifyRoundTrip(final Rel rel) { + final io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); + final Rel relReturned = protoRelConverter.from(protoRel); assertEquals(rel, relReturned); } } diff --git a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java index b84ef8bd2..eea75eab1 100644 --- a/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java +++ b/core/src/test/java/io/substrait/extendedexpression/ExtendedExpressionRoundTripTest.java @@ -34,16 +34,16 @@ private static Stream expressionReferenceProvider() { @ParameterizedTest @MethodSource("expressionReferenceProvider") - void testRoundTrip(ExtendedExpression.ExpressionReferenceBase expressionReference) { - List expressionReferences = new ArrayList<>(); + void testRoundTrip(final ExtendedExpression.ExpressionReferenceBase expressionReference) { + final List expressionReferences = new ArrayList<>(); expressionReferences.add(expressionReference); - NamedStruct namedStruct = getImmutableNamedStruct(); + final NamedStruct namedStruct = getImmutableNamedStruct(); assertExtendedExpressionOperation(expressionReferences, namedStruct); } @Test void getNoExpressionDefined() { - IllegalStateException illegalStateException = + final IllegalStateException illegalStateException = Assertions.assertThrows( IllegalStateException.class, () -> ImmutableExpressionReference.builder().addOutputNames("new-column").build()); @@ -56,7 +56,7 @@ void getNoExpressionDefined() { @Test void getNoAggregateFunctionDefined() { - IllegalStateException illegalStateException = + final IllegalStateException illegalStateException = Assertions.assertThrows( IllegalStateException.class, () -> @@ -87,7 +87,7 @@ private static ImmutableExpressionReference getFieldReferenceExpression() { } private static ImmutableExpressionReference getScalarFunctionExpression() { - Expression.ScalarFunctionInvocation scalarFunctionInvocation = + final Expression.ScalarFunctionInvocation scalarFunctionInvocation = new SubstraitBuilder(defaultExtensionCollection) .scalarFn( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC_DECIMAL, @@ -106,7 +106,7 @@ private static ImmutableExpressionReference getScalarFunctionExpression() { } private static ImmutableAggregateFunctionReference getAggregateFunctionReference() { - Aggregate.Measure measure = + final Aggregate.Measure measure = Aggregate.Measure.builder() .function( AggregateFunctionInvocation.builder() @@ -140,22 +140,22 @@ private static NamedStruct getImmutableNamedStruct() { } private static void assertExtendedExpressionOperation( - List expressionReferences, - NamedStruct namedStruct) { + final List expressionReferences, + final NamedStruct namedStruct) { // initial pojo - ExtendedExpression extendedExpressionPojoInitial = + final ExtendedExpression extendedExpressionPojoInitial = ImmutableExtendedExpression.builder() .referredExpressions(expressionReferences) .baseSchema(namedStruct) .build(); // proto - io.substrait.proto.ExtendedExpression extendedExpressionProto = + final io.substrait.proto.ExtendedExpression extendedExpressionProto = new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoInitial); // get pojo from proto - ExtendedExpression extendedExpressionPojoFinal = + final ExtendedExpression extendedExpressionPojoFinal = new ProtoExtendedExpressionConverter().from(extendedExpressionProto); Assertions.assertEquals(extendedExpressionPojoInitial, extendedExpressionPojoFinal); diff --git a/core/src/test/java/io/substrait/extension/ExtensionCollectionMergeTest.java b/core/src/test/java/io/substrait/extension/ExtensionCollectionMergeTest.java index c2c1a58b2..9974f138b 100644 --- a/core/src/test/java/io/substrait/extension/ExtensionCollectionMergeTest.java +++ b/core/src/test/java/io/substrait/extension/ExtensionCollectionMergeTest.java @@ -11,7 +11,7 @@ class ExtensionCollectionMergeTest { @Test void testMergeCollectionsWithDifferentUriUrnMappings() { - String yaml1 = + final String yaml1 = "%YAML 1.2\n" + "---\n" + "urn: extension:ns1:collection1\n" @@ -21,7 +21,7 @@ void testMergeCollectionsWithDifferentUriUrnMappings() { + " - args: []\n" + " return: boolean\n"; - String yaml2 = + final String yaml2 = "%YAML 1.2\n" + "---\n" + "urn: extension:ns2:collection2\n" @@ -31,12 +31,12 @@ void testMergeCollectionsWithDifferentUriUrnMappings() { + " - args: []\n" + " return: i32\n"; - SimpleExtension.ExtensionCollection collection1 = + final SimpleExtension.ExtensionCollection collection1 = SimpleExtension.load("uri1://extensions", yaml1); - SimpleExtension.ExtensionCollection collection2 = + final SimpleExtension.ExtensionCollection collection2 = SimpleExtension.load("uri2://extensions", yaml2); - SimpleExtension.ExtensionCollection merged = collection1.merge(collection2); + final SimpleExtension.ExtensionCollection merged = collection1.merge(collection2); assertEquals("extension:ns1:collection1", merged.getUrnFromUri("uri1://extensions")); assertEquals("extension:ns2:collection2", merged.getUrnFromUri("uri2://extensions")); @@ -48,7 +48,7 @@ void testMergeCollectionsWithDifferentUriUrnMappings() { @Test void testMergeCollectionsWithIdenticalMappings() { - String yaml = + final String yaml = "%YAML 1.2\n" + "---\n" + "urn: extension:shared:extension\n" @@ -58,10 +58,12 @@ void testMergeCollectionsWithIdenticalMappings() { + " - args: []\n" + " return: boolean\n"; - SimpleExtension.ExtensionCollection collection1 = SimpleExtension.load("shared://uri", yaml); - SimpleExtension.ExtensionCollection collection2 = SimpleExtension.load("shared://uri", yaml); + final SimpleExtension.ExtensionCollection collection1 = + SimpleExtension.load("shared://uri", yaml); + final SimpleExtension.ExtensionCollection collection2 = + SimpleExtension.load("shared://uri", yaml); - SimpleExtension.ExtensionCollection merged = + final SimpleExtension.ExtensionCollection merged = assertDoesNotThrow(() -> collection1.merge(collection2)); assertEquals("extension:shared:extension", merged.getUrnFromUri("shared://uri")); @@ -70,17 +72,18 @@ void testMergeCollectionsWithIdenticalMappings() { @Test void testMergeCollectionsWithConflictingMappings() { - String yaml1 = + final String yaml1 = "%YAML 1.2\n" + "---\n" + "urn: extension:conflict:urn1\n" + "scalar_functions: []\n"; - String yaml2 = + final String yaml2 = "%YAML 1.2\n" + "---\n" + "urn: extension:conflict:urn2\n" + "scalar_functions: []\n"; - SimpleExtension.ExtensionCollection collection1 = SimpleExtension.load("conflict://uri", yaml1); - SimpleExtension.ExtensionCollection collection2 = + final SimpleExtension.ExtensionCollection collection1 = + SimpleExtension.load("conflict://uri", yaml1); + final SimpleExtension.ExtensionCollection collection2 = SimpleExtension.load("conflict://uri", yaml2); // Same URI, different URN - IllegalArgumentException exception = + final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> collection1.merge(collection2)); assertTrue(exception.getMessage().contains("Key already exists in map with different value")); } diff --git a/core/src/test/java/io/substrait/extension/ExtensionCollectionUriUrnTest.java b/core/src/test/java/io/substrait/extension/ExtensionCollectionUriUrnTest.java index 8be2d6138..0367304ee 100644 --- a/core/src/test/java/io/substrait/extension/ExtensionCollectionUriUrnTest.java +++ b/core/src/test/java/io/substrait/extension/ExtensionCollectionUriUrnTest.java @@ -11,14 +11,14 @@ class ExtensionCollectionUriUrnTest { @Test void testHasUrnAndHasUri() { - String yamlContent = + final String yamlContent = "%YAML 1.2\n" + "---\n" + "urn: extension:test:exists\n" + "scalar_functions:\n" + " - name: test_function\n"; - SimpleExtension.ExtensionCollection collection = + final SimpleExtension.ExtensionCollection collection = SimpleExtension.load("file:///tmp/test.yaml", yamlContent); assertTrue(collection.getUrnFromUri("file:///tmp/test.yaml") != null); @@ -29,10 +29,10 @@ void testHasUrnAndHasUri() { @Test void testGetNonexistentMappings() { - String yamlContent = + final String yamlContent = "%YAML 1.2\n" + "---\n" + "urn: extension:test:minimal\n" + "scalar_functions: []\n"; - SimpleExtension.ExtensionCollection collection = + final SimpleExtension.ExtensionCollection collection = SimpleExtension.load("minimal://extension", yamlContent); assertNull(collection.getUrnFromUri("nonexistent://uri")); @@ -41,17 +41,17 @@ void testGetNonexistentMappings() { @Test void testEmptyUriThrowsException() { - String yamlContent = + final String yamlContent = "%YAML 1.2\n" + "---\n" + "urn: extension:test:empty\n" + "scalar_functions: []\n"; - IllegalArgumentException exception = + final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> SimpleExtension.load("", yamlContent)); assertTrue(exception.getMessage().contains("URI cannot be null or empty")); } @Test void testNullUriThrowsException() { - String yamlContent = + final String yamlContent = "%YAML 1.2\n" + "---\n" + "urn: extension:test:null\n" + "scalar_functions: []\n"; // The system throws NPE when null is passed, which is expected behavior diff --git a/core/src/test/java/io/substrait/extension/ExtensionCollectorUriUrnTest.java b/core/src/test/java/io/substrait/extension/ExtensionCollectorUriUrnTest.java index abfec1310..706277a01 100644 --- a/core/src/test/java/io/substrait/extension/ExtensionCollectorUriUrnTest.java +++ b/core/src/test/java/io/substrait/extension/ExtensionCollectorUriUrnTest.java @@ -9,29 +9,29 @@ class ExtensionCollectorUriUrnTest { @Test void testExtensionCollectorScalarFuncWithoutURI() { - String uri = "test://uri"; - BidiMap uriUrnMap = new BidiMap(); + final String uri = "test://uri"; + final BidiMap uriUrnMap = new BidiMap(); uriUrnMap.put(uri, "extension:test:basic"); - SimpleExtension.ExtensionCollection extensionCollection = + final SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); - ExtensionCollector collector = new ExtensionCollector(extensionCollection); + final ExtensionCollector collector = new ExtensionCollector(extensionCollection); - SimpleExtension.ScalarFunctionVariant func = + final SimpleExtension.ScalarFunctionVariant func = ImmutableSimpleExtension.ScalarFunctionVariant.builder() .urn("extension:test:basic") .name("test_func") .returnType(io.substrait.function.TypeExpressionCreator.REQUIRED.BOOLEAN) .build(); - int functionRef = collector.getFunctionReference(func); + final int functionRef = collector.getFunctionReference(func); assertEquals(1, functionRef); - Plan.Builder planBuilder = Plan.newBuilder(); + final Plan.Builder planBuilder = Plan.newBuilder(); collector.addExtensionsToPlan(planBuilder); - Plan plan = planBuilder.build(); + final Plan plan = planBuilder.build(); assertEquals(1, plan.getExtensionUrnsCount()); assertEquals("extension:test:basic", plan.getExtensionUrns(0).getUrn()); diff --git a/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java b/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java index 71cc9a337..92971be4b 100644 --- a/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java +++ b/core/src/test/java/io/substrait/extension/ImmutableExtensionLookupUriUrnTest.java @@ -15,26 +15,26 @@ class ImmutableExtensionLookupUriUrnTest { @Test void testUrnResolutionWorks() { // Create URN-only plan (normal case) - SimpleExtensionURN urnProto = + final SimpleExtensionURN urnProto = SimpleExtensionURN.newBuilder() .setExtensionUrnAnchor(1) .setUrn("extension:test:urn") .build(); - SimpleExtensionDeclaration.ExtensionFunction func = + final SimpleExtensionDeclaration.ExtensionFunction func = SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(1) .setName("test_func") .setExtensionUrnReference(1) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); - Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + final Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); // Test with no ExtensionCollection (no URI/URN mapping available) - ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + final ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); assertEquals("extension:test:urn", lookup.functionAnchorMap.get(1).urn()); assertEquals("test_func", lookup.functionAnchorMap.get(1).key()); @@ -43,33 +43,33 @@ void testUrnResolutionWorks() { @Test void testUriToUrnFallbackWorks() { // Create an ExtensionCollection with URI/URN mapping - BidiMap uriUrnMap = new BidiMap<>(); + final BidiMap uriUrnMap = new BidiMap<>(); uriUrnMap.put("http://example.com/extensions/test", "extension:test:mapped"); - SimpleExtension.ExtensionCollection extensionCollection = + final SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); // Create URI-only plan (legacy case) - SimpleExtensionURI uriProto = + final SimpleExtensionURI uriProto = SimpleExtensionURI.newBuilder() .setExtensionUriAnchor(1) .setUri("http://example.com/extensions/test") .build(); - SimpleExtensionDeclaration.ExtensionFunction func = + final SimpleExtensionDeclaration.ExtensionFunction func = SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(1) .setName("legacy_func") .setExtensionUriReference(1) // References the URI anchor (deprecated field) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); - Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + final Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); // Test with URI/URN mapping - should resolve URI to URN - ImmutableExtensionLookup lookup = + final ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); assertEquals("extension:test:mapped", lookup.functionAnchorMap.get(1).urn()); @@ -79,26 +79,26 @@ void testUriToUrnFallbackWorks() { @Test void testUriWithoutMappingThrowsError() { // Create URI-only plan without mapping - SimpleExtensionURI uriProto = + final SimpleExtensionURI uriProto = SimpleExtensionURI.newBuilder() .setExtensionUriAnchor(1) .setUri("http://example.com/unmapped") .build(); - SimpleExtensionDeclaration.ExtensionFunction func = + final SimpleExtensionDeclaration.ExtensionFunction func = SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(1) .setName("unmapped_func") .setExtensionUriReference(1) // References the URI anchor .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); - Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + final Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); // Should throw error - URI present but no mapping available - IllegalStateException exception = + final IllegalStateException exception = assertThrows( IllegalStateException.class, () -> { @@ -113,20 +113,20 @@ void testUriWithoutMappingThrowsError() { @Test void testMissingUrnAndUriThrowsError() { // Create plan with missing URN/URI reference - SimpleExtensionDeclaration.ExtensionFunction func = + final SimpleExtensionDeclaration.ExtensionFunction func = SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(1) .setName("missing_func") .setExtensionUrnReference(999) // Non-existent reference .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); - Plan plan = Plan.newBuilder().addExtensions(decl).build(); + final Plan plan = Plan.newBuilder().addExtensions(decl).build(); // Should throw error - neither URN nor URI found - IllegalStateException exception = + final IllegalStateException exception = assertThrows( IllegalStateException.class, () -> { @@ -144,25 +144,25 @@ void testMissingUrnAndUriThrowsError() { @Test void testFunctionCase1_NonZeroUrnReference() { // Case 1: Non-zero URN reference resolves - SimpleExtensionURN urnProto = + final SimpleExtensionURN urnProto = SimpleExtensionURN.newBuilder() .setExtensionUrnAnchor(1) .setUrn("extension:test:case1") .build(); - SimpleExtensionDeclaration.ExtensionFunction func = + final SimpleExtensionDeclaration.ExtensionFunction func = SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(1) .setName("case1_func") .setExtensionUrnReference(1) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); - Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + final Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); - ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + final ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); assertEquals("extension:test:case1", lookup.functionAnchorMap.get(1).urn()); assertEquals("case1_func", lookup.functionAnchorMap.get(1).key()); @@ -171,31 +171,31 @@ void testFunctionCase1_NonZeroUrnReference() { @Test void testFunctionCase2_NonZeroUriReference() { // Case 2: Non-zero URI reference resolves via mapping - BidiMap uriUrnMap = new BidiMap<>(); + final BidiMap uriUrnMap = new BidiMap<>(); uriUrnMap.put("http://example.com/case2", "extension:test:case2"); - SimpleExtension.ExtensionCollection extensionCollection = + final SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); - SimpleExtensionURI uriProto = + final SimpleExtensionURI uriProto = SimpleExtensionURI.newBuilder() .setExtensionUriAnchor(1) .setUri("http://example.com/case2") .build(); - SimpleExtensionDeclaration.ExtensionFunction func = + final SimpleExtensionDeclaration.ExtensionFunction func = SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(1) .setName("case2_func") .setExtensionUriReference(1) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); - Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + final Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); - ImmutableExtensionLookup lookup = + final ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); assertEquals("extension:test:case2", lookup.functionAnchorMap.get(1).urn()); @@ -205,25 +205,25 @@ void testFunctionCase2_NonZeroUriReference() { @Test void testFunctionCase3_ZeroBothResolveConsistent() { // Case 3: Both 0 references resolve to consistent URN - BidiMap uriUrnMap = new BidiMap<>(); + final BidiMap uriUrnMap = new BidiMap<>(); uriUrnMap.put("http://example.com/case3", "extension:test:case3"); - SimpleExtension.ExtensionCollection extensionCollection = + final SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); - SimpleExtensionURN urnProto = + final SimpleExtensionURN urnProto = SimpleExtensionURN.newBuilder() .setExtensionUrnAnchor(0) .setUrn("extension:test:case3") .build(); - SimpleExtensionURI uriProto = + final SimpleExtensionURI uriProto = SimpleExtensionURI.newBuilder() .setExtensionUriAnchor(0) .setUri("http://example.com/case3") .build(); - SimpleExtensionDeclaration.ExtensionFunction func = + final SimpleExtensionDeclaration.ExtensionFunction func = SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(1) .setName("case3_func") @@ -231,17 +231,17 @@ void testFunctionCase3_ZeroBothResolveConsistent() { .setExtensionUriReference(0) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); - Plan plan = + final Plan plan = Plan.newBuilder() .addExtensionUrns(urnProto) .addExtensionUris(uriProto) .addExtensions(decl) .build(); - ImmutableExtensionLookup lookup = + final ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); assertEquals("extension:test:case3", lookup.functionAnchorMap.get(1).urn()); @@ -251,25 +251,25 @@ void testFunctionCase3_ZeroBothResolveConsistent() { @Test void testFunctionCase3_ZeroBothResolveConflict() { // Case 3: Both 0 references resolve but to different URNs - should throw - BidiMap uriUrnMap = new BidiMap<>(); + final BidiMap uriUrnMap = new BidiMap<>(); uriUrnMap.put("http://example.com/conflict", "extension:test:different"); - SimpleExtension.ExtensionCollection extensionCollection = + final SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); - SimpleExtensionURN urnProto = + final SimpleExtensionURN urnProto = SimpleExtensionURN.newBuilder() .setExtensionUrnAnchor(0) .setUrn("extension:test:original") .build(); - SimpleExtensionURI uriProto = + final SimpleExtensionURI uriProto = SimpleExtensionURI.newBuilder() .setExtensionUriAnchor(0) .setUri("http://example.com/conflict") .build(); - SimpleExtensionDeclaration.ExtensionFunction func = + final SimpleExtensionDeclaration.ExtensionFunction func = SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(1) .setName("conflict_func") @@ -277,17 +277,17 @@ void testFunctionCase3_ZeroBothResolveConflict() { .setExtensionUriReference(0) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); - Plan plan = + final Plan plan = Plan.newBuilder() .addExtensionUrns(urnProto) .addExtensionUris(uriProto) .addExtensions(decl) .build(); - IllegalStateException exception = + final IllegalStateException exception = assertThrows( IllegalStateException.class, () -> { @@ -301,13 +301,13 @@ void testFunctionCase3_ZeroBothResolveConflict() { @Test void testFunctionCase4_ZeroUrnOnly() { // Case 4: Only 0 URN reference resolves - SimpleExtensionURN urnProto = + final SimpleExtensionURN urnProto = SimpleExtensionURN.newBuilder() .setExtensionUrnAnchor(0) .setUrn("extension:test:case4") .build(); - SimpleExtensionDeclaration.ExtensionFunction func = + final SimpleExtensionDeclaration.ExtensionFunction func = SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(1) .setName("case4_func") @@ -315,12 +315,12 @@ void testFunctionCase4_ZeroUrnOnly() { .setExtensionUriReference(0) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); - Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + final Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); - ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + final ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); assertEquals("extension:test:case4", lookup.functionAnchorMap.get(1).urn()); assertEquals("case4_func", lookup.functionAnchorMap.get(1).key()); @@ -329,19 +329,19 @@ void testFunctionCase4_ZeroUrnOnly() { @Test void testFunctionCase5_ZeroUriOnly() { // Case 5: Only 0 URI reference resolves - BidiMap uriUrnMap = new BidiMap<>(); + final BidiMap uriUrnMap = new BidiMap<>(); uriUrnMap.put("http://example.com/case5", "extension:test:case5"); - SimpleExtension.ExtensionCollection extensionCollection = + final SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); - SimpleExtensionURI uriProto = + final SimpleExtensionURI uriProto = SimpleExtensionURI.newBuilder() .setExtensionUriAnchor(0) .setUri("http://example.com/case5") .build(); - SimpleExtensionDeclaration.ExtensionFunction func = + final SimpleExtensionDeclaration.ExtensionFunction func = SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(1) .setName("case5_func") @@ -349,12 +349,12 @@ void testFunctionCase5_ZeroUriOnly() { .setExtensionUriReference(0) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionFunction(func).build(); - Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + final Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); - ImmutableExtensionLookup lookup = + final ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); assertEquals("extension:test:case5", lookup.functionAnchorMap.get(1).urn()); @@ -368,25 +368,25 @@ void testFunctionCase5_ZeroUriOnly() { @Test void testTypeCase1_NonZeroUrnReference() { // Case 1: Non-zero URN reference resolves - SimpleExtensionURN urnProto = + final SimpleExtensionURN urnProto = SimpleExtensionURN.newBuilder() .setExtensionUrnAnchor(1) .setUrn("extension:test:case1") .build(); - SimpleExtensionDeclaration.ExtensionType type = + final SimpleExtensionDeclaration.ExtensionType type = SimpleExtensionDeclaration.ExtensionType.newBuilder() .setTypeAnchor(1) .setName("case1_type") .setExtensionUrnReference(1) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); - Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + final Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); - ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + final ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); assertEquals("extension:test:case1", lookup.typeAnchorMap.get(1).urn()); assertEquals("case1_type", lookup.typeAnchorMap.get(1).key()); @@ -395,31 +395,31 @@ void testTypeCase1_NonZeroUrnReference() { @Test void testTypeCase2_NonZeroUriReference() { // Case 2: Non-zero URI reference resolves via mapping - BidiMap uriUrnMap = new BidiMap<>(); + final BidiMap uriUrnMap = new BidiMap<>(); uriUrnMap.put("http://example.com/case2", "extension:test:case2"); - SimpleExtension.ExtensionCollection extensionCollection = + final SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); - SimpleExtensionURI uriProto = + final SimpleExtensionURI uriProto = SimpleExtensionURI.newBuilder() .setExtensionUriAnchor(1) .setUri("http://example.com/case2") .build(); - SimpleExtensionDeclaration.ExtensionType type = + final SimpleExtensionDeclaration.ExtensionType type = SimpleExtensionDeclaration.ExtensionType.newBuilder() .setTypeAnchor(1) .setName("case2_type") .setExtensionUriReference(1) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); - Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + final Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); - ImmutableExtensionLookup lookup = + final ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); assertEquals("extension:test:case2", lookup.typeAnchorMap.get(1).urn()); @@ -429,25 +429,25 @@ void testTypeCase2_NonZeroUriReference() { @Test void testTypeCase3_ZeroBothResolveConsistent() { // Case 3: Both 0 references resolve to consistent URN - BidiMap uriUrnMap = new BidiMap<>(); + final BidiMap uriUrnMap = new BidiMap<>(); uriUrnMap.put("http://example.com/case3", "extension:test:case3"); - SimpleExtension.ExtensionCollection extensionCollection = + final SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); - SimpleExtensionURN urnProto = + final SimpleExtensionURN urnProto = SimpleExtensionURN.newBuilder() .setExtensionUrnAnchor(0) .setUrn("extension:test:case3") .build(); - SimpleExtensionURI uriProto = + final SimpleExtensionURI uriProto = SimpleExtensionURI.newBuilder() .setExtensionUriAnchor(0) .setUri("http://example.com/case3") .build(); - SimpleExtensionDeclaration.ExtensionType type = + final SimpleExtensionDeclaration.ExtensionType type = SimpleExtensionDeclaration.ExtensionType.newBuilder() .setTypeAnchor(1) .setName("case3_type") @@ -455,17 +455,17 @@ void testTypeCase3_ZeroBothResolveConsistent() { .setExtensionUriReference(0) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); - Plan plan = + final Plan plan = Plan.newBuilder() .addExtensionUrns(urnProto) .addExtensionUris(uriProto) .addExtensions(decl) .build(); - ImmutableExtensionLookup lookup = + final ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); assertEquals("extension:test:case3", lookup.typeAnchorMap.get(1).urn()); @@ -475,25 +475,25 @@ void testTypeCase3_ZeroBothResolveConsistent() { @Test void testTypeCase3_ZeroBothResolveConflict() { // Case 3: Both 0 references resolve but to different URNs - should throw - BidiMap uriUrnMap = new BidiMap<>(); + final BidiMap uriUrnMap = new BidiMap<>(); uriUrnMap.put("http://example.com/conflict", "extension:test:different"); - SimpleExtension.ExtensionCollection extensionCollection = + final SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); - SimpleExtensionURN urnProto = + final SimpleExtensionURN urnProto = SimpleExtensionURN.newBuilder() .setExtensionUrnAnchor(0) .setUrn("extension:test:original") .build(); - SimpleExtensionURI uriProto = + final SimpleExtensionURI uriProto = SimpleExtensionURI.newBuilder() .setExtensionUriAnchor(0) .setUri("http://example.com/conflict") .build(); - SimpleExtensionDeclaration.ExtensionType type = + final SimpleExtensionDeclaration.ExtensionType type = SimpleExtensionDeclaration.ExtensionType.newBuilder() .setTypeAnchor(1) .setName("conflict_type") @@ -501,17 +501,17 @@ void testTypeCase3_ZeroBothResolveConflict() { .setExtensionUriReference(0) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); - Plan plan = + final Plan plan = Plan.newBuilder() .addExtensionUrns(urnProto) .addExtensionUris(uriProto) .addExtensions(decl) .build(); - IllegalStateException exception = + final IllegalStateException exception = assertThrows( IllegalStateException.class, () -> { @@ -525,13 +525,13 @@ void testTypeCase3_ZeroBothResolveConflict() { @Test void testTypeCase4_ZeroUrnOnly() { // Case 4: Only 0 URN reference resolves - SimpleExtensionURN urnProto = + final SimpleExtensionURN urnProto = SimpleExtensionURN.newBuilder() .setExtensionUrnAnchor(0) .setUrn("extension:test:case4") .build(); - SimpleExtensionDeclaration.ExtensionType type = + final SimpleExtensionDeclaration.ExtensionType type = SimpleExtensionDeclaration.ExtensionType.newBuilder() .setTypeAnchor(1) .setName("case4_type") @@ -539,12 +539,12 @@ void testTypeCase4_ZeroUrnOnly() { .setExtensionUriReference(0) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); - Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); + final Plan plan = Plan.newBuilder().addExtensionUrns(urnProto).addExtensions(decl).build(); - ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); + final ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder().from(plan).build(); assertEquals("extension:test:case4", lookup.typeAnchorMap.get(1).urn()); assertEquals("case4_type", lookup.typeAnchorMap.get(1).key()); @@ -553,19 +553,19 @@ void testTypeCase4_ZeroUrnOnly() { @Test void testTypeCase5_ZeroUriOnly() { // Case 5: Only 0 URI reference resolves - BidiMap uriUrnMap = new BidiMap<>(); + final BidiMap uriUrnMap = new BidiMap<>(); uriUrnMap.put("http://example.com/case5", "extension:test:case5"); - SimpleExtension.ExtensionCollection extensionCollection = + final SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); - SimpleExtensionURI uriProto = + final SimpleExtensionURI uriProto = SimpleExtensionURI.newBuilder() .setExtensionUriAnchor(0) .setUri("http://example.com/case5") .build(); - SimpleExtensionDeclaration.ExtensionType type = + final SimpleExtensionDeclaration.ExtensionType type = SimpleExtensionDeclaration.ExtensionType.newBuilder() .setTypeAnchor(1) .setName("case5_type") @@ -573,12 +573,12 @@ void testTypeCase5_ZeroUriOnly() { .setExtensionUriReference(0) .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); - Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + final Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); - ImmutableExtensionLookup lookup = + final ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); assertEquals("extension:test:case5", lookup.typeAnchorMap.get(1).urn()); @@ -588,31 +588,31 @@ void testTypeCase5_ZeroUriOnly() { @Test void testTypeUriToUrnFallbackWorks() { // Test the same logic but for types instead of functions - BidiMap uriUrnMap = new BidiMap<>(); + final BidiMap uriUrnMap = new BidiMap<>(); uriUrnMap.put("http://example.com/types/test", "extension:types:mapped"); - SimpleExtension.ExtensionCollection extensionCollection = + final SimpleExtension.ExtensionCollection extensionCollection = SimpleExtension.ExtensionCollection.builder().uriUrnMap(uriUrnMap).build(); - SimpleExtensionURI uriProto = + final SimpleExtensionURI uriProto = SimpleExtensionURI.newBuilder() .setExtensionUriAnchor(1) .setUri("http://example.com/types/test") .build(); - SimpleExtensionDeclaration.ExtensionType type = + final SimpleExtensionDeclaration.ExtensionType type = SimpleExtensionDeclaration.ExtensionType.newBuilder() .setTypeAnchor(1) .setName("legacy_type") .setExtensionUriReference(1) // References the URI anchor .build(); - SimpleExtensionDeclaration decl = + final SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder().setExtensionType(type).build(); - Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); + final Plan plan = Plan.newBuilder().addExtensionUris(uriProto).addExtensions(decl).build(); - ImmutableExtensionLookup lookup = + final ImmutableExtensionLookup lookup = ImmutableExtensionLookup.builder(extensionCollection).from(plan).build(); assertEquals("extension:types:mapped", lookup.typeAnchorMap.get(1).urn()); diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index abb4008d5..c78186a72 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -32,28 +32,29 @@ class TypeExtensionTest { final SimpleExtension.ExtensionCollection extensionCollection; { - String path = "/extensions/custom_extensions.yaml"; - InputStream inputStream = this.getClass().getResourceAsStream(path); + final String path = "/extensions/custom_extensions.yaml"; + final InputStream inputStream = this.getClass().getResourceAsStream(path); extensionCollection = SimpleExtension.load(path, inputStream); } final SubstraitBuilder b = new SubstraitBuilder(extensionCollection); - Type customType1 = b.userDefinedType(URN, "customType1"); - Type customType2 = b.userDefinedType(URN, "customType2"); + final Type customType1 = b.userDefinedType(URN, "customType1"); + final Type customType2 = b.userDefinedType(URN, "customType2"); final PlanProtoConverter planProtoConverter = new PlanProtoConverter(); final ProtoPlanConverter protoPlanConverter = new ProtoPlanConverter(extensionCollection); @Test void roundtripCustomType() { // CREATE TABLE example (custom_type_column custom_type1, i64_column BIGINT); - List tableName = Stream.of("example").collect(Collectors.toList()); - List columnNames = + final List tableName = Stream.of("example").collect(Collectors.toList()); + final List columnNames = Stream.of("custom_type_column", "i64_column").collect(Collectors.toList()); - List types = Stream.of(customType1, R.I64).collect(Collectors.toList()); + final List types = + Stream.of(customType1, R.I64).collect(Collectors.toList()); // SELECT custom_type_column, scalar1(custom_type_column), scalar2(i64_column) // FROM example - Plan plan = + final Plan plan = b.plan( b.root( b.project( @@ -70,20 +71,20 @@ void roundtripCustomType() { .collect(Collectors.toList()), b.namedScan(tableName, columnNames, types)))); - io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan); - Plan planReturned = protoPlanConverter.from(protoPlan); + final io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan); + final Plan planReturned = protoPlanConverter.from(protoPlan); assertEquals(plan, planReturned); } @Test void roundtripNumberedAnyTypes() { - List tableName = Stream.of("example").collect(Collectors.toList()); - List columnNames = + final List tableName = Stream.of("example").collect(Collectors.toList()); + final List columnNames = Stream.of("array_i64_type_column", "array_i64_column").collect(Collectors.toList()); - List types = + final List types = Stream.of(REQUIRED.list(R.I64)).collect(Collectors.toList()); - Plan plan = + final Plan plan = b.plan( b.root( b.project( @@ -93,8 +94,8 @@ void roundtripNumberedAnyTypes() { URN, "array_index:list_i64", R.I64, b.fieldReference(input, 0))) .collect(Collectors.toList()), b.namedScan(tableName, columnNames, types)))); - io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan); - Plan planReturned = protoPlanConverter.from(protoPlan); + final io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan); + final Plan planReturned = protoPlanConverter.from(protoPlan); assertEquals(plan, planReturned); } } diff --git a/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java b/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java index e96b8133f..ad05229d7 100644 --- a/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java +++ b/core/src/test/java/io/substrait/extension/UriUrnMigrationEndToEndTest.java @@ -27,18 +27,18 @@ class UriUrnMigrationEndToEndTest { /** Load a proto Plan from a JSON resource file using JsonFormat */ - private Plan loadPlanFromJson(String resourcePath) throws IOException { + private Plan loadPlanFromJson(final String resourcePath) throws IOException { try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream(resourcePath)) { if (inputStream == null) { throw new IOException("Resource not found: " + resourcePath); } - String jsonContent = + final String jsonContent = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) .lines() .collect(Collectors.joining("\n")); - Plan.Builder planBuilder = Plan.newBuilder(); + final Plan.Builder planBuilder = Plan.newBuilder(); JsonFormat.parser().merge(jsonContent, planBuilder); return planBuilder.build(); } @@ -48,7 +48,7 @@ private Plan loadPlanFromJson(String resourcePath) throws IOException { void testUriUrnMigrationEndToEnd() throws IOException { // List of (inputPath, expectedPath, extensionCollection) tuples - List testCases = + final List testCases = Arrays.asList( new String[] { "uri-urn-migration/uri-only-input-plan.json", @@ -71,20 +71,20 @@ void testUriUrnMigrationEndToEnd() throws IOException { "uri-urn-migration/zero-urn-resolution-expected-plan.json" }); - for (String[] testCase : testCases) { - String inputPath = testCase[0]; - String expectedPath = testCase[1]; + for (final String[] testCase : testCases) { + final String inputPath = testCase[0]; + final String expectedPath = testCase[1]; - Plan inputPlan = loadPlanFromJson(inputPath); - Plan expectedPlan = loadPlanFromJson(expectedPath); + final Plan inputPlan = loadPlanFromJson(inputPath); + final Plan expectedPlan = loadPlanFromJson(expectedPath); - ProtoPlanConverter protoToPojo = + final ProtoPlanConverter protoToPojo = new ProtoPlanConverter(DefaultExtensionCatalog.DEFAULT_COLLECTION); - io.substrait.plan.Plan pojoPlan = protoToPojo.from(inputPlan); + final io.substrait.plan.Plan pojoPlan = protoToPojo.from(inputPlan); - PlanProtoConverter pojoToProto = + final PlanProtoConverter pojoToProto = new PlanProtoConverter(DefaultExtensionCatalog.DEFAULT_COLLECTION); - Plan actualPlan = pojoToProto.toProto(pojoPlan); + final Plan actualPlan = pojoToProto.toProto(pojoPlan); assertEquals(expectedPlan, actualPlan); } @@ -92,12 +92,12 @@ void testUriUrnMigrationEndToEnd() throws IOException { @Test void testUnresolvableUriThrowsException() throws IOException { - Plan inputPlan = loadPlanFromJson("uri-urn-migration/unresolvable-uri-plan.json"); + final Plan inputPlan = loadPlanFromJson("uri-urn-migration/unresolvable-uri-plan.json"); - ProtoPlanConverter protoToPojo = + final ProtoPlanConverter protoToPojo = new ProtoPlanConverter(DefaultExtensionCatalog.DEFAULT_COLLECTION); - IllegalStateException exception = + final IllegalStateException exception = assertThrows( IllegalStateException.class, () -> { diff --git a/core/src/test/java/io/substrait/extension/UrnValidationTest.java b/core/src/test/java/io/substrait/extension/UrnValidationTest.java index 55541c2d5..612a45a6e 100644 --- a/core/src/test/java/io/substrait/extension/UrnValidationTest.java +++ b/core/src/test/java/io/substrait/extension/UrnValidationTest.java @@ -11,8 +11,9 @@ class UrnValidationTest { @Test void testMissingUrnThrowsException() { - String yamlWithoutUrn = "%YAML 1.2\n" + "---\n" + "scalar_functions:\n" + " - name: test\n"; - IllegalArgumentException exception = + final String yamlWithoutUrn = + "%YAML 1.2\n" + "---\n" + "scalar_functions:\n" + " - name: test\n"; + final IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, () -> SimpleExtension.load("some/uri", yamlWithoutUrn)); assertTrue(exception.getMessage().contains("Extension YAML file must contain a 'urn' field")); @@ -20,13 +21,13 @@ void testMissingUrnThrowsException() { @Test void testInvalidUrnFormatThrowsException() { - String yamlWithInvalidUrn = + final String yamlWithInvalidUrn = "%YAML 1.2\n" + "---\n" + "urn: invalid:format\n" + "scalar_functions:\n" + " - name: test\n"; - IllegalArgumentException exception = + final IllegalArgumentException exception = assertThrows( IllegalArgumentException.class, () -> SimpleExtension.load("some/uri", yamlWithInvalidUrn)); @@ -36,7 +37,7 @@ void testInvalidUrnFormatThrowsException() { @Test void testValidUrnWorks() { - String yamlWithValidUrn = + final String yamlWithValidUrn = "%YAML 1.2\n" + "---\n" + "urn: extension:test:valid\n" @@ -47,13 +48,13 @@ void testValidUrnWorks() { @Test void testUriUrnMapIsPopulated() { - String yamlWithValidUrn = + final String yamlWithValidUrn = "%YAML 1.2\n" + "---\n" + "urn: extension:test:valid\n" + "scalar_functions:\n" + " - name: test\n"; - SimpleExtension.ExtensionCollection collection = + final SimpleExtension.ExtensionCollection collection = SimpleExtension.load("test://uri", yamlWithValidUrn); assertEquals("extension:test:valid", collection.getUrnFromUri("test://uri")); } diff --git a/core/src/test/java/io/substrait/relation/AggregateRelTest.java b/core/src/test/java/io/substrait/relation/AggregateRelTest.java index bf1704d9c..c67356607 100644 --- a/core/src/test/java/io/substrait/relation/AggregateRelTest.java +++ b/core/src/test/java/io/substrait/relation/AggregateRelTest.java @@ -22,13 +22,13 @@ class AggregateRelTest extends TestBase { public static io.substrait.proto.NamedStruct createSchema() { - io.substrait.proto.Type i32Type = + final io.substrait.proto.Type i32Type = io.substrait.proto.Type.newBuilder() .setI32(io.substrait.proto.Type.I32.getDefaultInstance()) .build(); // Build a NamedStruct schema with two fields: col1, col2 - io.substrait.proto.Type.Struct structType = + final io.substrait.proto.Type.Struct structType = io.substrait.proto.Type.Struct.newBuilder().addTypes(i32Type).addTypes(i32Type).build(); return io.substrait.proto.NamedStruct.newBuilder() @@ -38,16 +38,16 @@ public static io.substrait.proto.NamedStruct createSchema() { .build(); } - public static io.substrait.proto.Expression createFieldReference(int col) { + public static io.substrait.proto.Expression createFieldReference(final int col) { // Build a ReferenceSegment that refers to struct field col - Expression.ReferenceSegment seg1 = + final Expression.ReferenceSegment seg1 = Expression.ReferenceSegment.newBuilder() .setStructField( Expression.ReferenceSegment.StructField.newBuilder().setField(col).build()) .build(); // Build a FieldReference that uses the directReference and a rootReference - Expression.FieldReference fieldRef1 = + final Expression.FieldReference fieldRef1 = Expression.FieldReference.newBuilder() .setDirectReference(seg1) .setRootReference(Expression.FieldReference.RootReference.getDefaultInstance()) @@ -59,51 +59,51 @@ public static io.substrait.proto.Expression createFieldReference(int col) { @Test void testDeprecatedGroupingExpressionConversion() { - Expression col1Ref = createFieldReference(0); - Expression col2Ref = createFieldReference(1); + final Expression col1Ref = createFieldReference(0); + final Expression col2Ref = createFieldReference(1); - AggregateRel.Grouping grouping = + final AggregateRel.Grouping grouping = AggregateRel.Grouping.newBuilder() .addGroupingExpressions(col1Ref) // deprecated proto form .addGroupingExpressions(col2Ref) .build(); // Build an input ReadRel - ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + final ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build(); // Build the AggregateRel with the new grouping_expressions field - AggregateRel aggrProto = + final AggregateRel aggrProto = AggregateRel.newBuilder() .setInput(Rel.newBuilder().setRead(readProto)) .addGroupings(grouping) .build(); - Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); - ProtoRelConverter converter = new ProtoRelConverter(functionLookup); - io.substrait.relation.Rel resultRel = converter.from(relProto); + final Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); + final ProtoRelConverter converter = new ProtoRelConverter(functionLookup); + final io.substrait.relation.Rel resultRel = converter.from(relProto); assertTrue(resultRel instanceof Aggregate); - Aggregate agg = (Aggregate) resultRel; + final Aggregate agg = (Aggregate) resultRel; assertEquals(1, agg.getGroupings().size()); assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); } @Test void testAggregateWithSingleGrouping() { - Expression col1Ref = createFieldReference(0); - Expression col2Ref = createFieldReference(1); + final Expression col1Ref = createFieldReference(0); + final Expression col2Ref = createFieldReference(1); - AggregateRel.Grouping grouping = + final AggregateRel.Grouping grouping = AggregateRel.Grouping.newBuilder() .addExpressionReferences(0) .addExpressionReferences(1) .build(); // Build an input ReadRel - ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + final ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build(); // Build the AggregateRel with the new grouping_expressions field - AggregateRel aggrProto = + final AggregateRel aggrProto = AggregateRel.newBuilder() .setInput(Rel.newBuilder().setRead(readProto)) .addGroupingExpressions(col1Ref) @@ -111,35 +111,35 @@ void testAggregateWithSingleGrouping() { .addGroupings(grouping) .build(); - Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); - ProtoRelConverter converter = new ProtoRelConverter(functionLookup); - io.substrait.relation.Rel resultRel = converter.from(relProto); + final Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); + final ProtoRelConverter converter = new ProtoRelConverter(functionLookup); + final io.substrait.relation.Rel resultRel = converter.from(relProto); assertTrue(resultRel instanceof Aggregate); - Aggregate agg = (Aggregate) resultRel; + final Aggregate agg = (Aggregate) resultRel; assertEquals(1, agg.getGroupings().size()); assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); } @Test void testAggregateWithMultipleGroupings() { - Expression col1Ref = createFieldReference(0); - Expression col2Ref = createFieldReference(1); + final Expression col1Ref = createFieldReference(0); + final Expression col2Ref = createFieldReference(1); - AggregateRel.Grouping grouping1 = + final AggregateRel.Grouping grouping1 = AggregateRel.Grouping.newBuilder() .addExpressionReferences(0) // new proto form .addExpressionReferences(1) .build(); - AggregateRel.Grouping grouping2 = + final AggregateRel.Grouping grouping2 = AggregateRel.Grouping.newBuilder().addExpressionReferences(1).build(); // Build an input ReadRel - ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + final ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build(); // Build the AggregateRel with the new grouping_expressions field - AggregateRel aggrProto = + final AggregateRel aggrProto = AggregateRel.newBuilder() .setInput(Rel.newBuilder().setRead(readProto)) .addGroupingExpressions(col1Ref) @@ -148,12 +148,12 @@ void testAggregateWithMultipleGroupings() { .addGroupings(grouping2) .build(); - Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); - ProtoRelConverter converter = new ProtoRelConverter(functionLookup); - io.substrait.relation.Rel resultRel = converter.from(relProto); + final Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); + final ProtoRelConverter converter = new ProtoRelConverter(functionLookup); + final io.substrait.relation.Rel resultRel = converter.from(relProto); assertTrue(resultRel instanceof Aggregate); - Aggregate agg = (Aggregate) resultRel; + final Aggregate agg = (Aggregate) resultRel; assertEquals(2, agg.getGroupings().size()); assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); assertEquals(1, agg.getGroupings().get(1).getExpressions().size()); diff --git a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java index 3876fd2ad..96a53205c 100644 --- a/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java +++ b/core/src/test/java/io/substrait/relation/ProtoRelConverterTest.java @@ -37,16 +37,16 @@ class DefaultAdvancedExtensionTests { final StringHolder enhanced = new StringHolder("ENHANCED"); final StringHolder optimized = new StringHolder("OPTIMIZED"); - Rel emptyAdvancedExtension = relWithExtension(AdvancedExtension.builder().build()); - Rel advancedExtensionWithOptimization = + final Rel emptyAdvancedExtension = relWithExtension(AdvancedExtension.builder().build()); + final Rel advancedExtensionWithOptimization = relWithExtension(AdvancedExtension.builder().addOptimizations(optimized).build()); - Rel advancedExtensionWithEnhancement = + final Rel advancedExtensionWithEnhancement = relWithExtension(AdvancedExtension.builder().enhancement(enhanced).build()); - Rel advancedExtensionWithEnhancementAndOptimization = + final Rel advancedExtensionWithEnhancementAndOptimization = relWithExtension( AdvancedExtension.builder().enhancement(enhanced).addOptimizations(optimized).build()); - Rel relWithExtension(AdvancedExtension advancedExtension) { + Rel relWithExtension(final AdvancedExtension advancedExtension) { return NamedScan.builder() .from(commonTable) .commonExtension(advancedExtension) @@ -158,36 +158,37 @@ class DetailsTest { @Test void extensionLeaf() { - Rel rel = ExtensionLeaf.from(new StringHolder("DETAILS")).build(); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); - Rel relReturned = protoRelConverter.from(protoRel); + final Rel rel = ExtensionLeaf.from(new StringHolder("DETAILS")).build(); + final io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); + final Rel relReturned = protoRelConverter.from(protoRel); assertNotEquals(rel, relReturned); } @Test void extensionSingle() { - Rel rel = ExtensionSingle.from(new StringHolder("DETAILS"), commonTable).build(); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); - Rel relReturned = protoRelConverter.from(protoRel); + final Rel rel = ExtensionSingle.from(new StringHolder("DETAILS"), commonTable).build(); + final io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); + final Rel relReturned = protoRelConverter.from(protoRel); assertNotEquals(rel, relReturned); } @Test void extensionMulti() { - Rel rel = ExtensionMulti.from(new StringHolder("DETAILS"), commonTable, commonTable).build(); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); - Rel relReturned = protoRelConverter.from(protoRel); + final Rel rel = + ExtensionMulti.from(new StringHolder("DETAILS"), commonTable, commonTable).build(); + final io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); + final Rel relReturned = protoRelConverter.from(protoRel); assertNotEquals(rel, relReturned); } @Test void extensionTable() { - Rel rel = ExtensionTable.from(new StringHolder("DETAILS")).build(); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); - Rel relReturned = protoRelConverter.from(protoRel); + final Rel rel = ExtensionTable.from(new StringHolder("DETAILS")).build(); + final io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); + final Rel relReturned = protoRelConverter.from(protoRel); assertNotEquals(rel, relReturned); } @@ -197,8 +198,8 @@ void extensionTable() { @Nested class HintsTest { - Stats createStats(boolean includeEmptyOptimization) { - ImmutableStats.Builder builder = Stats.builder(); + Stats createStats(final boolean includeEmptyOptimization) { + final ImmutableStats.Builder builder = Stats.builder(); builder.rowCount(42).recordSize(42); if (includeEmptyOptimization) { builder.extension(AdvancedExtension.builder().addOptimizations().build()); @@ -220,8 +221,8 @@ SavedComputation createSavedComputation() { .build(); } - RuntimeConstraint createRuntimeConstraint(boolean includeEmptyOptimization) { - ImmutableRuntimeConstraint.Builder builder = RuntimeConstraint.builder(); + RuntimeConstraint createRuntimeConstraint(final boolean includeEmptyOptimization) { + final ImmutableRuntimeConstraint.Builder builder = RuntimeConstraint.builder(); if (includeEmptyOptimization) { builder.extension(AdvancedExtension.builder().addOptimizations().build()); } @@ -230,7 +231,7 @@ RuntimeConstraint createRuntimeConstraint(boolean includeEmptyOptimization) { @Test void relWithCompleteHint() { - Hint test = + final Hint test = Hint.builder() .alias("TestHint") .addAllOutputNames(Arrays.asList("Hint 1", "Hint 2")) @@ -242,15 +243,15 @@ void relWithCompleteHint() { .runtimeConstraint(createRuntimeConstraint(true)) .build(); - Rel relWithCompleteHint = NamedScan.builder().from(commonTable).hint(test).build(); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithCompleteHint); - Rel relReturned = protoRelConverter.from(protoRel); + final Rel relWithCompleteHint = NamedScan.builder().from(commonTable).hint(test).build(); + final io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithCompleteHint); + final Rel relReturned = protoRelConverter.from(protoRel); assertEquals(relWithCompleteHint, relReturned); } @Test void relWithLoadedComputationHint() { - Hint test = + final Hint test = Hint.builder() .alias("TestHint") .addAllOutputNames(Arrays.asList("Hint 1", "Hint 2")) @@ -260,15 +261,17 @@ void relWithLoadedComputationHint() { .runtimeConstraint(createRuntimeConstraint(false)) .build(); - Rel relWithLoadedComputationHint = NamedScan.builder().from(commonTable).hint(test).build(); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithLoadedComputationHint); - Rel relReturned = protoRelConverter.from(protoRel); + final Rel relWithLoadedComputationHint = + NamedScan.builder().from(commonTable).hint(test).build(); + final io.substrait.proto.Rel protoRel = + relProtoConverter.toProto(relWithLoadedComputationHint); + final Rel relReturned = protoRelConverter.from(protoRel); assertEquals(relWithLoadedComputationHint, relReturned); } @Test void relWithSavedComputationHint() { - Hint test = + final Hint test = Hint.builder() .alias("TestHint") .addAllOutputNames(Arrays.asList("Hint 1", "Hint 2")) @@ -278,18 +281,20 @@ void relWithSavedComputationHint() { .runtimeConstraint(createRuntimeConstraint(false)) .build(); - Rel relWithSavedComputationHint = NamedScan.builder().from(commonTable).hint(test).build(); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithSavedComputationHint); - Rel relReturned = protoRelConverter.from(protoRel); + final Rel relWithSavedComputationHint = + NamedScan.builder().from(commonTable).hint(test).build(); + final io.substrait.proto.Rel protoRel = + relProtoConverter.toProto(relWithSavedComputationHint); + final Rel relReturned = protoRelConverter.from(protoRel); assertEquals(relWithSavedComputationHint, relReturned); } @Test void relWithMinimalHint() { - Hint test = Hint.builder().build(); - Rel relWithMinimalHint = NamedScan.builder().from(commonTable).hint(test).build(); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithMinimalHint); - Rel relReturned = protoRelConverter.from(protoRel); + final Hint test = Hint.builder().build(); + final Rel relWithMinimalHint = NamedScan.builder().from(commonTable).hint(test).build(); + final io.substrait.proto.Rel protoRel = relProtoConverter.toProto(relWithMinimalHint); + final Rel relReturned = protoRelConverter.from(protoRel); assertEquals(relWithMinimalHint, relReturned); } } diff --git a/core/src/test/java/io/substrait/relation/SetTest.java b/core/src/test/java/io/substrait/relation/SetTest.java index 8c857fb3a..622ed124a 100644 --- a/core/src/test/java/io/substrait/relation/SetTest.java +++ b/core/src/test/java/io/substrait/relation/SetTest.java @@ -19,24 +19,24 @@ class SetTest { @Test void deriveRecordTypeNullability() { - List names = + final List names = Arrays.asList("col1", "col2", "col3", "col4", "col5", "col6", "col7", "col8"); // From https://substrait.io/relations/logical_relations/#output-type-derivation-examples - EmptyScan input1 = + final EmptyScan input1 = EmptyScan.builder() .initialSchema(NamedStruct.of(names, getStruct(R, R, R, R, N, N, N, N))) .build(); - EmptyScan input2 = + final EmptyScan input2 = EmptyScan.builder() .initialSchema(NamedStruct.of(names, getStruct(R, R, N, N, R, R, N, N))) .build(); - EmptyScan input3 = + final EmptyScan input3 = EmptyScan.builder() .initialSchema(NamedStruct.of(names, getStruct(R, N, R, N, R, N, R, N))) .build(); - Map expecteds = new HashMap<>(); + final Map expecteds = new HashMap<>(); expecteds.put(Set.SetOp.MINUS_PRIMARY, getStruct(R, R, R, R, N, N, N, N)); expecteds.put(Set.SetOp.MINUS_PRIMARY_ALL, getStruct(R, R, R, R, N, N, N, N)); expecteds.put(Set.SetOp.MINUS_MULTISET, getStruct(R, R, R, R, N, N, N, N)); @@ -52,7 +52,7 @@ void deriveRecordTypeNullability() { if (setOp == Set.SetOp.UNKNOWN) { return; } - Type.Struct expected = expecteds.get(setOp); + final Type.Struct expected = expecteds.get(setOp); assertNotNull(expected, "Missing expected record type for " + setOp); assertEquals( expected, @@ -65,7 +65,7 @@ void deriveRecordTypeNullability() { }); } - private Type.Struct getStruct(Type... types) { + private Type.Struct getStruct(final Type... types) { return Type.Struct.builder().addFields(types).nullable(false).build(); } } diff --git a/core/src/test/java/io/substrait/relation/SpecVersionTest.java b/core/src/test/java/io/substrait/relation/SpecVersionTest.java index 669e30bdd..c2a306868 100644 --- a/core/src/test/java/io/substrait/relation/SpecVersionTest.java +++ b/core/src/test/java/io/substrait/relation/SpecVersionTest.java @@ -10,7 +10,7 @@ class SpecVersionTest { @Test void testSubstraitVersionDefaultValues() { - Version version = Version.DEFAULT_VERSION; + final Version version = Version.DEFAULT_VERSION; assertNotNull(version.getMajor()); assertNotNull(version.getMinor()); diff --git a/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java b/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java index df3b0a14c..f1799ff0f 100644 --- a/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java +++ b/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java @@ -24,7 +24,7 @@ class VirtualTableScanTest extends TestBase { @Test void check() { - VirtualTableScan virtualTableScan = + final VirtualTableScan virtualTableScan = ImmutableVirtualTableScan.builder() .initialSchema( NamedStruct.of( @@ -67,7 +67,7 @@ void check() { @Test void checkValidRowsWithSimpleTypes() { // Test with simple types and multiple rows - VirtualTableScan virtualTableScan = + final VirtualTableScan virtualTableScan = ImmutableVirtualTableScan.builder() .initialSchema( NamedStruct.of( @@ -172,9 +172,9 @@ void checkInvalidNullabilityMismatch() { } private Map mapOf( - Expression.Literal key, Expression.Literal value) { + final Expression.Literal key, final Expression.Literal value) { // Map.of() comes only in Java 9 and the "core" module is on Java 8 - HashMap map = new HashMap<>(); + final HashMap map = new HashMap<>(); map.put(key, value); return map; } diff --git a/core/src/test/java/io/substrait/type/parser/TestTypeParser.java b/core/src/test/java/io/substrait/type/parser/TestTypeParser.java index e9c39ccf5..4b51b918e 100644 --- a/core/src/test/java/io/substrait/type/parser/TestTypeParser.java +++ b/core/src/test/java/io/substrait/type/parser/TestTypeParser.java @@ -71,7 +71,7 @@ void derivationExpression() { "L1=1\nFIXEDCHAR"); } - private void simpleTests(ParseToPojo.Visitor v) { + private void simpleTests(final ParseToPojo.Visitor v) { test(v, r.I8, "I8"); test(v, r.I16, "I16"); test(v, r.I32, "I32"); @@ -90,7 +90,7 @@ private void simpleTests(ParseToPojo.Visitor v) { test(v, n.userDefined(URN, "foo"), "u!foo?"); } - private void compoundTests(ParseToPojo.Visitor v) { + private void compoundTests(final ParseToPojo.Visitor v) { test(v, r.fixedChar(1), "FIXEDCHAR<1>"); test(v, r.varChar(2), "VARCHAR<2>"); test(v, r.fixedBinary(3), "FIXEDBINARY<3>"); @@ -109,7 +109,7 @@ private void compoundTests(ParseToPojo.Visitor v) { test(v, n.map(r.I16, n.I8), "MAP?"); } - private void parameterizedTests(ParseToPojo.Visitor v) { + private void parameterizedTests(final ParseToPojo.Visitor v) { test(v, pn.listE(pr.parameter("K")), "List?"); test(v, pr.structE(r.I8, r.I16, n.I8, pr.parameter("K")), "STRUCT"); test(v, pr.parameter("any"), "any"); @@ -122,7 +122,8 @@ private void parameterizedTests(ParseToPojo.Visitor v) { test(v, pr.decimalE("14", "S"), "DECIMAL<14, S>"); } - private static void test(ParseToPojo.Visitor visitor, TypeExpression expected, String toParse) { + private static void test( + final ParseToPojo.Visitor visitor, final TypeExpression expected, final String toParse) { assertEquals(expected, TypeStringParser.parse(toParse, visitor)); } } diff --git a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java index 491428abb..0fb41f7b3 100644 --- a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java @@ -21,22 +21,23 @@ class AggregateRoundtripTest extends TestBase { - private void assertAggregateRoundtrip(Expression.AggregationInvocation invocation) { - Expression.DecimalLiteral expression = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); - Expression.StructLiteral literal = + private void assertAggregateRoundtrip(final Expression.AggregationInvocation invocation) { + final Expression.DecimalLiteral expression = + ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); + final Expression.StructLiteral literal = Expression.StructLiteral.builder().addFields(expression).build(); - io.substrait.relation.ImmutableVirtualTableScan input = + final io.substrait.relation.ImmutableVirtualTableScan input = VirtualTableScan.builder() .initialSchema(NamedStruct.of(Arrays.asList("decimal"), R.struct(R.decimal(10, 2)))) .addRows(literal) .build(); - ExtensionCollector functionCollector = new ExtensionCollector(); - RelProtoConverter to = new RelProtoConverter(functionCollector); - io.substrait.extension.SimpleExtension.ExtensionCollection extensions = + final ExtensionCollector functionCollector = new ExtensionCollector(); + final RelProtoConverter to = new RelProtoConverter(functionCollector); + final io.substrait.extension.SimpleExtension.ExtensionCollection extensions = defaultExtensionCollection; - ProtoRelConverter from = new ProtoRelConverter(functionCollector, extensions); + final ProtoRelConverter from = new ProtoRelConverter(functionCollector, extensions); - io.substrait.relation.ImmutableMeasure measure = + final io.substrait.relation.ImmutableMeasure measure = Aggregate.Measure.builder() .function( AggregateFunctionInvocation.builder() @@ -61,9 +62,9 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio .build()) .build(); - io.substrait.relation.ImmutableAggregate aggRel = + final io.substrait.relation.ImmutableAggregate aggRel = Aggregate.builder().input(input).measures(Arrays.asList(measure)).build(); - io.substrait.proto.Rel protoAggRel = to.toProto(aggRel); + final io.substrait.proto.Rel protoAggRel = to.toProto(aggRel); assertEquals( protoAggRel.getAggregate().getMeasuresList().get(0).getMeasure().getInvocation(), invocation.toProto()); @@ -72,7 +73,8 @@ private void assertAggregateRoundtrip(Expression.AggregationInvocation invocatio @Test void aggregateInvocationRoundtrip() { - for (Expression.AggregationInvocation invocation : Expression.AggregationInvocation.values()) { + for (final Expression.AggregationInvocation invocation : + Expression.AggregationInvocation.values()) { assertAggregateRoundtrip(invocation); } } diff --git a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java index afd8a494d..33b3f0cf7 100644 --- a/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ConsistentPartitionWindowRelRoundtripTest.java @@ -17,16 +17,16 @@ class ConsistentPartitionWindowRelRoundtripTest extends TestBase { @Test void consistentPartitionWindowRoundtripSingle() { - SimpleExtension.WindowFunctionVariant windowFunctionDeclaration = + final SimpleExtension.WindowFunctionVariant windowFunctionDeclaration = defaultExtensionCollection.getWindowFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); - Rel input = + final Rel input = b.namedScan( Arrays.asList("test"), Arrays.asList("a", "b", "c"), Arrays.asList(R.I64, R.I16, R.I32)); - Rel rel1 = + final Rel rel1 = ConsistentPartitionWindow.builder() .input(input) .windowFunctions( @@ -59,8 +59,8 @@ void consistentPartitionWindowRoundtripSingle() { .build())) .build(); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel1); - io.substrait.relation.Rel rel2 = protoRelConverter.from(protoRel); + final io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel1); + final io.substrait.relation.Rel rel2 = protoRelConverter.from(protoRel); assertEquals(rel1, rel2); // Make sure that the record types match I64, I16, I32 and then the I64 from the window @@ -70,20 +70,20 @@ void consistentPartitionWindowRoundtripSingle() { @Test void consistentPartitionWindowRoundtripMulti() { - SimpleExtension.WindowFunctionVariant windowFunctionLeadDeclaration = + final SimpleExtension.WindowFunctionVariant windowFunctionLeadDeclaration = defaultExtensionCollection.getWindowFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); - SimpleExtension.WindowFunctionVariant windowFunctionLagDeclaration = + final SimpleExtension.WindowFunctionVariant windowFunctionLagDeclaration = defaultExtensionCollection.getWindowFunction( SimpleExtension.FunctionAnchor.of( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "lead:any")); - Rel input = + final Rel input = b.namedScan( Arrays.asList("test"), Arrays.asList("a", "b", "c"), Arrays.asList(R.I64, R.I16, R.I32)); - Rel rel1 = + final Rel rel1 = ConsistentPartitionWindow.builder() .input(input) .windowFunctions( @@ -133,8 +133,8 @@ void consistentPartitionWindowRoundtripMulti() { .build())) .build(); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel1); - io.substrait.relation.Rel rel2 = protoRelConverter.from(protoRel); + final io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel1); + final io.substrait.relation.Rel rel2 = protoRelConverter.from(protoRel); assertEquals(rel1, rel2); // Make sure that the record types match I64, I16, I32 and then the I64 and I64 from the window diff --git a/core/src/test/java/io/substrait/type/proto/DdlRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/DdlRelRoundtripTest.java index 19e91aa29..7a0b81247 100644 --- a/core/src/test/java/io/substrait/type/proto/DdlRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/DdlRelRoundtripTest.java @@ -21,15 +21,15 @@ class DdlRelRoundtripTest extends TestBase { @Test void create() { - NamedStruct schema = + final NamedStruct schema = NamedStruct.of( Stream.of("column1", "column2").collect(Collectors.toList()), R.struct(R.I64, R.I64)); - Expression.StructLiteral defaults = + final Expression.StructLiteral defaults = ExpressionCreator.struct( false, ExpressionCreator.i64(false, 1), ExpressionCreator.i64(false, 2)); - NamedDdl command = + final NamedDdl command = NamedDdl.builder() .tableSchema(schema) .tableDefaults(defaults) @@ -43,23 +43,23 @@ void create() { @Test void alter() { - ProtoRelConverter protoRelConverter = + final ProtoRelConverter protoRelConverter = new StringHolderHandlingProtoRelConverter(functionCollector, defaultExtensionCollection); - StringHolder detail = new StringHolder("DETAIL"); + final StringHolder detail = new StringHolder("DETAIL"); - NamedStruct schema = + final NamedStruct schema = NamedStruct.of( Stream.of("column1", "column2").collect(Collectors.toList()), R.struct(R.I64, R.I64)); - Expression.StructLiteral defaults = + final Expression.StructLiteral defaults = ExpressionCreator.struct( false, ExpressionCreator.i64(false, 1), ExpressionCreator.i64(false, 2)); - VirtualTableScan virtTable = + final VirtualTableScan virtTable = VirtualTableScan.builder().initialSchema(schema).addRows(defaults).build(); - ExtensionDdl command = + final ExtensionDdl command = ExtensionDdl.builder() .viewDefinition(virtTable) .tableSchema(schema) @@ -69,8 +69,8 @@ void alter() { .object(ExtensionDdl.DdlObject.VIEW) .build(); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(command); - Rel relReturned = protoRelConverter.from(protoRel); + final io.substrait.proto.Rel protoRel = relProtoConverter.toProto(command); + final Rel relReturned = protoRelConverter.from(protoRel); assertEquals(command, relReturned); } } diff --git a/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java index 95b06b8fc..847d0ae0f 100644 --- a/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExchangeRelRoundtripTest.java @@ -24,14 +24,14 @@ class ExchangeRelRoundtripTest extends TestBase { @Test void broadcastExchange() { - Rel exchange = BroadcastExchange.builder().input(baseTable).partitionCount(1).build(); + final Rel exchange = BroadcastExchange.builder().input(baseTable).partitionCount(1).build(); verifyRoundTrip(exchange); } @Test void roundRobinExchange() { - Rel exchange = + final Rel exchange = RoundRobinExchange.builder().input(baseTable).exact(true).partitionCount(1).build(); verifyRoundTrip(exchange); @@ -39,7 +39,7 @@ void roundRobinExchange() { @Test void scatterExchange() { - Rel exchange = + final Rel exchange = ScatterExchange.builder() .input(baseTable) .addFields(b.fieldReference(baseTable, 0)) @@ -51,7 +51,7 @@ void scatterExchange() { @Test void singleBucketExchange() { - Rel exchange = + final Rel exchange = SingleBucketExchange.builder() .input(baseTable) .partitionCount(1) @@ -63,7 +63,7 @@ void singleBucketExchange() { @Test void multiBucketExchange() { - Rel exchange = + final Rel exchange = MultiBucketExchange.builder() .input(baseTable) .expression(b.fieldReference(baseTable, 0)) @@ -76,21 +76,21 @@ void multiBucketExchange() { @Test void exchangeWithTargets() { - AbstractExchangeRel.ExchangeTarget target1 = + final AbstractExchangeRel.ExchangeTarget target1 = AbstractExchangeRel.ExchangeTarget.builder() .partitionIds(Arrays.asList(0, 1)) .type(TargetType.Uri.builder().uri("hdfs://example.com/data1").build()) .build(); - AbstractExchangeRel.ExchangeTarget target2 = + final AbstractExchangeRel.ExchangeTarget target2 = AbstractExchangeRel.ExchangeTarget.builder() .partitionIds(Arrays.asList(2, 3)) .type(TargetType.Uri.builder().uri("hdfs://example.com/data2").build()) .build(); - List targets = Arrays.asList(target1, target2); + final List targets = Arrays.asList(target1, target2); - Rel exchange = + final Rel exchange = BroadcastExchange.builder().input(baseTable).targets(targets).partitionCount(1).build(); verifyRoundTrip(exchange); @@ -98,9 +98,10 @@ void exchangeWithTargets() { @Test void nestedExchangeRelations() { - Rel innerExchange = BroadcastExchange.builder().input(baseTable).partitionCount(1).build(); + final Rel innerExchange = + BroadcastExchange.builder().input(baseTable).partitionCount(1).build(); - Rel outerExchange = + final Rel outerExchange = RoundRobinExchange.builder().input(innerExchange).exact(false).partitionCount(1).build(); verifyRoundTrip(outerExchange); diff --git a/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java index 127ac6c00..3e13c51ad 100644 --- a/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java @@ -18,11 +18,11 @@ class ExpandRelRoundtripTest extends TestBase { Stream.of("column1", "column2").collect(Collectors.toList()), Stream.of(R.I64, R.I64).collect(Collectors.toList())); - private Expand.ExpandField getConsistentField(int index) { + private Expand.ExpandField getConsistentField(final int index) { return Expand.ConsistentField.builder().expression(b.fieldReference(input, index)).build(); } - private Expand.ExpandField getSwitchingField(List indexes) { + private Expand.ExpandField getSwitchingField(final List indexes) { return Expand.SwitchingField.builder() .addAllDuplicates( indexes.stream() @@ -33,7 +33,7 @@ private Expand.ExpandField getSwitchingField(List indexes) { @Test void expandConsistent() { - Rel rel = + final Rel rel = Expand.builder() .from(b.expand(__ -> Collections.emptyList(), input)) .hint( @@ -50,7 +50,7 @@ void expandConsistent() { @Test void expandSwitching() { - Rel rel = + final Rel rel = Expand.builder() .from(b.expand(__ -> Collections.emptyList(), input)) .hint(Hint.builder().addAllOutputNames(Arrays.asList("name1", "name2")).build()) diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index ccabb6188..6aaa47ba9 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -67,17 +67,17 @@ class ExtensionRoundtripTest extends TestBase { .build(); @Override - protected void verifyRoundTrip(Rel rel) { - RelProtoConverter relProtoConverter = + protected void verifyRoundTrip(final Rel rel) { + final RelProtoConverter relProtoConverter = new StringHolderHandlingRelProtoConverter(functionCollector); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); - Rel relReturned = protoRelConverter.from(protoRel); + final io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); + final Rel relReturned = protoRelConverter.from(protoRel); assertEquals(rel, relReturned); } @Test void virtualTable() { - Rel rel = + final Rel rel = VirtualTableScan.builder() .initialSchema(NamedStruct.of(Collections.emptyList(), R.struct())) .addRows(Expression.StructLiteral.builder().fields(Collections.emptyList()).build()) @@ -89,7 +89,7 @@ void virtualTable() { @Test void localFiles() { - Rel rel = + final Rel rel = LocalFiles.builder() .initialSchema( NamedStruct.of( @@ -102,7 +102,7 @@ void localFiles() { @Test void namedScan() { - Rel rel = + final Rel rel = NamedScan.builder() .from( b.namedScan( @@ -115,13 +115,13 @@ void namedScan() { @Test void extensionTable() { - Rel rel = ExtensionTable.from(detail).build(); + final Rel rel = ExtensionTable.from(detail).build(); verifyRoundTrip(rel); } @Test void filter() { - Rel rel = + final Rel rel = Filter.builder() .from(b.filter(__ -> b.bool(true), commonTable)) .commonExtension(commonExtension) @@ -132,7 +132,7 @@ void filter() { @Test void fetch() { - Rel rel = + final Rel rel = Fetch.builder() .from(b.fetch(1, 2, commonTable)) .commonExtension(commonExtension) @@ -143,7 +143,7 @@ void fetch() { @Test void aggregate() { - Rel rel = + final Rel rel = Aggregate.builder() .from(b.aggregate(b::grouping, __ -> Collections.emptyList(), commonTable)) .commonExtension(commonExtension) @@ -154,7 +154,7 @@ void aggregate() { @Test void sort() { - Rel rel = + final Rel rel = Sort.builder() .from(b.sort(__ -> Collections.emptyList(), commonTable)) .commonExtension(commonExtension) @@ -165,7 +165,7 @@ void sort() { @Test void join() { - Rel rel = + final Rel rel = Join.builder() .from(b.innerJoin(__ -> b.bool(true), commonTable, commonTable)) .commonExtension(commonExtension) @@ -177,9 +177,9 @@ void join() { @Test void hashJoin() { // with empty keys - List leftEmptyKeys = Collections.emptyList(); - List rightEmptyKeys = Collections.emptyList(); - Rel relWithoutKeys = + final List leftEmptyKeys = Collections.emptyList(); + final List rightEmptyKeys = Collections.emptyList(); + final Rel relWithoutKeys = HashJoin.builder() .from( b.hashJoin( @@ -197,9 +197,9 @@ void hashJoin() { @Test void mergeJoin() { // with empty keys - List leftEmptyKeys = Collections.emptyList(); - List rightEmptyKeys = Collections.emptyList(); - Rel relWithoutKeys = + final List leftEmptyKeys = Collections.emptyList(); + final List rightEmptyKeys = Collections.emptyList(); + final Rel relWithoutKeys = MergeJoin.builder() .from( b.mergeJoin( @@ -216,7 +216,7 @@ void mergeJoin() { @Test void nestedLoopJoin() { - Rel rel = + final Rel rel = NestedLoopJoin.builder() .from( b.nestedLoopJoin( @@ -229,7 +229,7 @@ void nestedLoopJoin() { @Test void project() { - Rel rel = + final Rel rel = Project.builder() .from(b.project(__ -> Collections.emptyList(), commonTable)) .commonExtension(commonExtension) @@ -240,7 +240,7 @@ void project() { @Test void expand() { - Rel rel = + final Rel rel = Expand.builder() .from(b.expand(__ -> Collections.emptyList(), commonTable)) .commonExtension(commonExtension) @@ -250,7 +250,7 @@ void expand() { @Test void set() { - Rel rel = + final Rel rel = Set.builder() .from(b.set(Set.SetOp.UNION_ALL, commonTable)) .commonExtension(commonExtension) @@ -261,13 +261,14 @@ void set() { @Test void extensionSingleRel() { - Rel rel = ExtensionSingle.from(detail, commonTable).commonExtension(commonExtension).build(); + final Rel rel = + ExtensionSingle.from(detail, commonTable).commonExtension(commonExtension).build(); verifyRoundTrip(rel); } @Test void extensionMultiRel() { - Rel rel = + final Rel rel = ExtensionMulti.from(detail, commonTable, commonTable) .commonExtension(commonExtension) .build(); @@ -276,13 +277,13 @@ void extensionMultiRel() { @Test void extensionLeafRel() { - Rel rel = ExtensionLeaf.from(detail).commonExtension(commonExtension).build(); + final Rel rel = ExtensionLeaf.from(detail).commonExtension(commonExtension).build(); verifyRoundTrip(rel); } @Test void cross() { - Rel rel = + final Rel rel = Cross.builder() .from(b.cross(commonTable, commonTable)) .commonExtension(commonExtension) @@ -310,7 +311,7 @@ class ExtensionThroughExpression { @Test void scalarSubquery() { - Project rel = + final Project rel = b.project( input -> Stream.of( @@ -326,7 +327,7 @@ void scalarSubquery() { @Test void inPredicate() { - Project rel = + final Project rel = b.project( input -> Stream.of( @@ -341,7 +342,7 @@ void inPredicate() { @Test void setPredicate() { - Project rel = + final Project rel = b.project( input -> Stream.of( diff --git a/core/src/test/java/io/substrait/type/proto/FieldReferenceRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/FieldReferenceRoundtripTest.java index a04f0bad4..63ec9fc3b 100644 --- a/core/src/test/java/io/substrait/type/proto/FieldReferenceRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/FieldReferenceRoundtripTest.java @@ -29,7 +29,7 @@ class FieldReferenceRoundtripTest extends TestBase { @Test void simpleStructFieldReference() { // Test simple root struct field reference via projection - Rel projection = + final Rel projection = Project.builder().input(baseTable).addExpressions(b.fieldReference(baseTable, 0)).build(); verifyRoundTrip(projection); @@ -38,7 +38,7 @@ void simpleStructFieldReference() { @Test void multipleFieldReferences() { // Test multiple field references in same projection - Rel projection = + final Rel projection = Project.builder() .input(baseTable) .addExpressions( @@ -53,9 +53,10 @@ void multipleFieldReferences() { @Test void fieldReferenceInFilter() { // Test field reference in filter condition - Expression condition = b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); + final Expression condition = + b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); - Rel filter = Filter.builder().input(baseTable).condition(condition).build(); + final Rel filter = Filter.builder().input(baseTable).condition(condition).build(); verifyRoundTrip(filter); } @@ -63,9 +64,9 @@ void fieldReferenceInFilter() { @Test void fieldReferenceInComplexExpression() { // Test field reference as part of arithmetic expression - Expression add = b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); + final Expression add = b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); - Rel projection = Project.builder().input(baseTable).addExpressions(add).build(); + final Rel projection = Project.builder().input(baseTable).addExpressions(add).build(); verifyRoundTrip(projection); } @@ -73,13 +74,13 @@ void fieldReferenceInComplexExpression() { @Test void fieldReferenceInNestedProjection() { // Test field reference through nested projections - Rel firstProjection = + final Rel firstProjection = Project.builder() .input(baseTable) .addExpressions(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 2)) .build(); - Rel secondProjection = + final Rel secondProjection = Project.builder() .input(firstProjection) .addExpressions(b.fieldReference(firstProjection, 1)) @@ -91,7 +92,7 @@ void fieldReferenceInNestedProjection() { @Test void fieldReferenceAllFields() { // Test referencing all fields - Rel projection = + final Rel projection = Project.builder() .input(baseTable) .addExpressions( @@ -107,12 +108,12 @@ void fieldReferenceAllFields() { @Test void fieldReferenceWithBooleanLogic() { // Test field references in boolean expressions - Expression condition = + final Expression condition = b.and( b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)), b.equal(b.fieldReference(baseTable, 2), b.str("test"))); - Rel filter = Filter.builder().input(baseTable).condition(condition).build(); + final Rel filter = Filter.builder().input(baseTable).condition(condition).build(); verifyRoundTrip(filter); } @@ -120,10 +121,10 @@ void fieldReferenceWithBooleanLogic() { @Test void fieldReferenceInMultipleArithmetic() { // Test multiple field references in arithmetic - Expression add = b.add(b.fieldReference(baseTable, 1), b.fieldReference(baseTable, 1)); - Expression multiply = b.multiply(add, b.fieldReference(baseTable, 1)); + final Expression add = b.add(b.fieldReference(baseTable, 1), b.fieldReference(baseTable, 1)); + final Expression multiply = b.multiply(add, b.fieldReference(baseTable, 1)); - Rel projection = Project.builder().input(baseTable).addExpressions(multiply).build(); + final Rel projection = Project.builder().input(baseTable).addExpressions(multiply).build(); verifyRoundTrip(projection); } @@ -131,7 +132,7 @@ void fieldReferenceInMultipleArithmetic() { @Test void fieldReferenceReordering() { // Test field reordering through projection (accessing fields out of order) - Rel projection = + final Rel projection = Project.builder() .input(baseTable) .addExpressions( @@ -146,7 +147,7 @@ void fieldReferenceReordering() { @Test void sameFieldReferencedMultipleTimes() { // Test same field referenced multiple times - Rel projection = + final Rel projection = Project.builder() .input(baseTable) .addExpressions( diff --git a/core/src/test/java/io/substrait/type/proto/FilterRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/FilterRelRoundtripTest.java index 4ee77133f..dcc633df9 100644 --- a/core/src/test/java/io/substrait/type/proto/FilterRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/FilterRelRoundtripTest.java @@ -19,9 +19,9 @@ class FilterRelRoundtripTest extends TestBase { @Test void simpleEqualityFilter() { // Filter: WHERE id = 100 - Expression condition = b.equal(b.fieldReference(baseTable, 0), b.i32(100)); + final Expression condition = b.equal(b.fieldReference(baseTable, 0), b.i32(100)); - Rel filter = Filter.builder().input(baseTable).condition(condition).build(); + final Rel filter = Filter.builder().input(baseTable).condition(condition).build(); verifyRoundTrip(filter); } @@ -29,9 +29,9 @@ void simpleEqualityFilter() { @Test void stringComparisonFilter() { // Filter: WHERE name = 'John' - Expression condition = b.equal(b.fieldReference(baseTable, 2), b.str("John")); + final Expression condition = b.equal(b.fieldReference(baseTable, 2), b.str("John")); - Rel filter = Filter.builder().input(baseTable).condition(condition).build(); + final Rel filter = Filter.builder().input(baseTable).condition(condition).build(); verifyRoundTrip(filter); } @@ -39,12 +39,12 @@ void stringComparisonFilter() { @Test void andConditionFilter() { // Filter: WHERE id = 10 AND amount = 100.0 - Expression condition = + final Expression condition = b.and( b.equal(b.fieldReference(baseTable, 0), b.i32(10)), b.equal(b.fieldReference(baseTable, 1), b.fp64(100.0))); - Rel filter = Filter.builder().input(baseTable).condition(condition).build(); + final Rel filter = Filter.builder().input(baseTable).condition(condition).build(); verifyRoundTrip(filter); } @@ -52,12 +52,12 @@ void andConditionFilter() { @Test void orConditionFilter() { // Filter: WHERE id = 5 OR id = 95 - Expression condition = + final Expression condition = b.or( b.equal(b.fieldReference(baseTable, 0), b.i32(5)), b.equal(b.fieldReference(baseTable, 0), b.i32(95))); - Rel filter = Filter.builder().input(baseTable).condition(condition).build(); + final Rel filter = Filter.builder().input(baseTable).condition(condition).build(); verifyRoundTrip(filter); } @@ -65,15 +65,15 @@ void orConditionFilter() { @Test void complexBooleanFilter() { // Filter: WHERE (id = 10 AND amount = 100) OR status = true - Expression andCondition = + final Expression andCondition = b.and( b.equal(b.fieldReference(baseTable, 0), b.i32(10)), b.equal(b.fieldReference(baseTable, 1), b.fp64(100.0))); - Expression condition = + final Expression condition = b.or(andCondition, b.equal(b.fieldReference(baseTable, 3), b.bool(true))); - Rel filter = Filter.builder().input(baseTable).condition(condition).build(); + final Rel filter = Filter.builder().input(baseTable).condition(condition).build(); verifyRoundTrip(filter); } @@ -81,9 +81,10 @@ void complexBooleanFilter() { @Test void multipleFieldComparison() { // Filter: WHERE id = amount (comparing two fields) - Expression condition = b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 1)); + final Expression condition = + b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 1)); - Rel filter = Filter.builder().input(baseTable).condition(condition).build(); + final Rel filter = Filter.builder().input(baseTable).condition(condition).build(); verifyRoundTrip(filter); } @@ -91,11 +92,11 @@ void multipleFieldComparison() { @Test void nestedFilters() { // Apply filter on top of another filter - Expression firstCondition = b.equal(b.fieldReference(baseTable, 0), b.i32(10)); - Rel firstFilter = Filter.builder().input(baseTable).condition(firstCondition).build(); + final Expression firstCondition = b.equal(b.fieldReference(baseTable, 0), b.i32(10)); + final Rel firstFilter = Filter.builder().input(baseTable).condition(firstCondition).build(); - Expression secondCondition = b.equal(b.fieldReference(firstFilter, 1), b.fp64(100.0)); - Rel secondFilter = Filter.builder().input(firstFilter).condition(secondCondition).build(); + final Expression secondCondition = b.equal(b.fieldReference(firstFilter, 1), b.fp64(100.0)); + final Rel secondFilter = Filter.builder().input(firstFilter).condition(secondCondition).build(); verifyRoundTrip(secondFilter); } @@ -103,10 +104,10 @@ void nestedFilters() { @Test void filterWithArithmeticExpression() { // Filter: WHERE amount * 2 = 100 - Expression multiply = b.multiply(b.fieldReference(baseTable, 1), b.fp64(2.0)); - Expression condition = b.equal(multiply, b.fp64(100.0)); + final Expression multiply = b.multiply(b.fieldReference(baseTable, 1), b.fp64(2.0)); + final Expression condition = b.equal(multiply, b.fp64(100.0)); - Rel filter = Filter.builder().input(baseTable).condition(condition).build(); + final Rel filter = Filter.builder().input(baseTable).condition(condition).build(); verifyRoundTrip(filter); } @@ -114,9 +115,9 @@ void filterWithArithmeticExpression() { @Test void filterWithBooleanField() { // Filter: WHERE status (direct boolean field) - Expression condition = b.fieldReference(baseTable, 3); + final Expression condition = b.fieldReference(baseTable, 3); - Rel filter = Filter.builder().input(baseTable).condition(condition).build(); + final Rel filter = Filter.builder().input(baseTable).condition(condition).build(); verifyRoundTrip(filter); } @@ -124,10 +125,10 @@ void filterWithBooleanField() { @Test void filterWithAddition() { // Filter: WHERE id + id = id (field with itself) - Expression add = b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); - Expression condition = b.equal(add, b.fieldReference(baseTable, 0)); + final Expression add = b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); + final Expression condition = b.equal(add, b.fieldReference(baseTable, 0)); - Rel filter = Filter.builder().input(baseTable).condition(condition).build(); + final Rel filter = Filter.builder().input(baseTable).condition(condition).build(); verifyRoundTrip(filter); } diff --git a/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java index 268729c68..e128284ec 100644 --- a/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java @@ -40,7 +40,8 @@ class GenericRoundtripTest extends TestBase { * parameters. If the param generation has failed the {@link UnsupportedTypeGenerationException} e * is populated, and the test will be ignored (kept here for tracking). */ - void roundtripTest(Method m, List paramInst, UnsupportedTypeGenerationException e) + void roundtripTest( + final Method m, final List paramInst, final UnsupportedTypeGenerationException e) throws InvocationTargetException, IllegalAccessException { // If there is an UncoveredTypeGenerationException we ignore this test @@ -49,10 +50,10 @@ void roundtripTest(Method m, List paramInst, UnsupportedTypeGenerationEx } // roundtrip to protobuff and back and check equality - Expression val = (Expression) m.invoke(null, paramInst.toArray(new Object[0])); + final Expression val = (Expression) m.invoke(null, paramInst.toArray(new Object[0])); - ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); - ProtoExpressionConverter from = + final ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); + final ProtoExpressionConverter from = new ProtoExpressionConverter( null, null, @@ -64,27 +65,27 @@ void roundtripTest(Method m, List paramInst, UnsupportedTypeGenerationEx // Parametrized case generator private static Collection generateInvocations() { - ArrayList invocations = new ArrayList<>(); + final ArrayList invocations = new ArrayList<>(); // We list all public and static methods of ExpressionCreator - List methodsToTest = getMethods(ExpressionCreator.class, true, true); + final List methodsToTest = getMethods(ExpressionCreator.class, true, true); // We generate synthetic input params (for a subset of types we support) - for (Method m : methodsToTest) { + for (final Method m : methodsToTest) { try { invocations.add(arguments(m, instantiateParams(m), null)); - } catch (UnsupportedTypeGenerationException e) { + } catch (final UnsupportedTypeGenerationException e) { invocations.add(arguments(m, null, e)); } } return invocations; } - private static List instantiateParams(Method m) + private static List instantiateParams(final Method m) throws UnsupportedTypeGenerationException { - List l = new ArrayList<>(); - for (Class param : m.getParameterTypes()) { - Object val = valGenerator(param); + final List l = new ArrayList<>(); + for (final Class param : m.getParameterTypes()) { + final Object val = valGenerator(param); if (val == null) { throw new UnsupportedTypeGenerationException( "We can't yet handle generation for type: " + param.getName()); @@ -95,10 +96,10 @@ private static List instantiateParams(Method m) } private static List getMethods( - Class c, boolean filterPublicOnly, boolean filterStaticOnly) { - Method[] allMethods = c.getMethods(); - List selectedMethods = new ArrayList<>(); - for (Method m : allMethods) { + final Class c, final boolean filterPublicOnly, final boolean filterStaticOnly) { + final Method[] allMethods = c.getMethods(); + final List selectedMethods = new ArrayList<>(); + for (final Method m : allMethods) { if ((filterPublicOnly && !Modifier.isPublic(m.getModifiers())) || (filterStaticOnly && !Modifier.isStatic(m.getModifiers()))) { continue; @@ -108,7 +109,7 @@ private static List getMethods( return selectedMethods; } - private static Object valGenerator(Class type) { + private static Object valGenerator(final Class type) { // For each "type" generate some random value if (type.equals(Boolean.TYPE) || type.equals(Boolean.class)) { @@ -146,7 +147,7 @@ private static class UnsupportedTypeGenerationException extends Exception { private static final long serialVersionUID = -8627552468610061245L; - public UnsupportedTypeGenerationException(String s) { + public UnsupportedTypeGenerationException(final String s) { super(s); } } diff --git a/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java index 9b42136fb..037d7fefc 100644 --- a/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java @@ -25,8 +25,8 @@ void ifThenNotNullable() { ExpressionCreator.i64(false, 2)); assertFalse(ifRel.getType().nullable()); - ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); - ProtoExpressionConverter from = + final ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); + final ProtoExpressionConverter from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); assertEquals(ifRel, from.from(ifRel.accept(to, EmptyVisitationContext.INSTANCE))); } @@ -40,8 +40,8 @@ void ifThenNullable() { ExpressionCreator.i64(false, 2)); assertTrue(ifRel.getType().nullable()); - ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); - ProtoExpressionConverter from = + final ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); + final ProtoExpressionConverter from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); assertEquals(ifRel, from.from(ifRel.accept(to, EmptyVisitationContext.INSTANCE))); } diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java index 45e8749ba..ff7191c53 100644 --- a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -25,9 +25,9 @@ class JoinRoundtripTest extends TestBase { @Test void hashJoin() { - List leftKeys = Arrays.asList(0, 1); - List rightKeys = Arrays.asList(2, 0); - Rel relWithoutKeys = + final List leftKeys = Arrays.asList(0, 1); + final List rightKeys = Arrays.asList(2, 0); + final Rel relWithoutKeys = HashJoin.builder() .from(b.hashJoin(leftKeys, rightKeys, HashJoin.JoinType.INNER, leftTable, rightTable)) .build(); @@ -36,9 +36,9 @@ void hashJoin() { @Test void mergeJoin() { - List leftKeys = Arrays.asList(0, 1); - List rightKeys = Arrays.asList(2, 0); - Rel relWithoutKeys = + final List leftKeys = Arrays.asList(0, 1); + final List rightKeys = Arrays.asList(2, 0); + final Rel relWithoutKeys = MergeJoin.builder() .from(b.mergeJoin(leftKeys, rightKeys, MergeJoin.JoinType.INNER, leftTable, rightTable)) .build(); @@ -47,8 +47,8 @@ void mergeJoin() { @Test void nestedLoopJoin() { - List inputRels = Arrays.asList(leftTable, rightTable); - Rel rel = + final List inputRels = Arrays.asList(leftTable, rightTable); + final Rel rel = NestedLoopJoin.builder() .from( b.nestedLoopJoin( diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index bfb417365..2523167ae 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -15,10 +15,10 @@ class LiteralRoundtripTest extends TestBase { @Test void decimal() { - io.substrait.expression.Expression.DecimalLiteral val = + final io.substrait.expression.Expression.DecimalLiteral val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); - ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); - ProtoExpressionConverter from = + final ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); + final ProtoExpressionConverter from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); assertEquals(val, from.from(val.accept(to, EmptyVisitationContext.INSTANCE))); } diff --git a/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java index 7143ef707..af1220ff4 100644 --- a/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java @@ -19,8 +19,8 @@ class LocalFilesRoundtripTest extends TestBase { - private void assertLocalFilesRoundtrip(FileOrFiles file) { - io.substrait.relation.ImmutableLocalFiles.Builder builder = + private void assertLocalFilesRoundtrip(final FileOrFiles file) { + final io.substrait.relation.ImmutableLocalFiles.Builder builder = LocalFiles.builder() .initialSchema( NamedStruct.builder() @@ -48,15 +48,15 @@ private void assertLocalFilesRoundtrip(FileOrFiles file) { ExpressionCreator.i32(false, 1))) .ifPresent(builder::filter); - io.substrait.relation.ImmutableLocalFiles localFiles = builder.build(); - io.substrait.proto.Rel protoFileRel = relProtoConverter.toProto(localFiles); + final io.substrait.relation.ImmutableLocalFiles localFiles = builder.build(); + final io.substrait.proto.Rel protoFileRel = relProtoConverter.toProto(localFiles); assertTrue(protoFileRel.getRead().hasFilter()); assertEquals(protoFileRel, relProtoConverter.toProto(protoRelConverter.from(protoFileRel))); } private ImmutableFileOrFiles.Builder setPath( - ImmutableFileOrFiles.Builder builder, - ReadRel.LocalFiles.FileOrFiles.PathTypeCase pathTypeCase) { + final ImmutableFileOrFiles.Builder builder, + final ReadRel.LocalFiles.FileOrFiles.PathTypeCase pathTypeCase) { switch (pathTypeCase) { case URI_PATH: return builder.pathType(FileOrFiles.PathType.URI_PATH).path("path"); @@ -74,8 +74,8 @@ private ImmutableFileOrFiles.Builder setPath( } private ImmutableFileOrFiles.Builder setFileFormat( - ImmutableFileOrFiles.Builder builder, - ReadRel.LocalFiles.FileOrFiles.FileFormatCase fileFormatCase) { + final ImmutableFileOrFiles.Builder builder, + final ReadRel.LocalFiles.FileOrFiles.FileFormatCase fileFormatCase) { switch (fileFormatCase) { case PARQUET: return builder.fileFormat(FileFormat.ParquetReadOptions.builder().build()); diff --git a/core/src/test/java/io/substrait/type/proto/ProjectRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ProjectRelRoundtripTest.java index 66a0fd417..cd2dca323 100644 --- a/core/src/test/java/io/substrait/type/proto/ProjectRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ProjectRelRoundtripTest.java @@ -19,7 +19,7 @@ class ProjectRelRoundtripTest extends TestBase { @Test void simpleProjection() { // Project single field - Rel projection = + final Rel projection = Project.builder().input(baseTable).addExpressions(b.fieldReference(baseTable, 0)).build(); verifyRoundTrip(projection); @@ -28,7 +28,7 @@ void simpleProjection() { @Test void multipleFieldProjection() { // Project multiple fields - Rel projection = + final Rel projection = Project.builder() .input(baseTable) .addExpressions( @@ -43,9 +43,10 @@ void multipleFieldProjection() { @Test void projectionWithComputedExpression() { // Project with computed expression: col_a + 3 (both I64) - Expression addExpr = b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); + final Expression addExpr = + b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); - Rel projection = Project.builder().input(baseTable).addExpressions(addExpr).build(); + final Rel projection = Project.builder().input(baseTable).addExpressions(addExpr).build(); verifyRoundTrip(projection); } @@ -53,11 +54,11 @@ void projectionWithComputedExpression() { @Test void projectionWithMultipleComputedExpressions() { // Project with multiple computed expressions - Expression add = b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); - Expression multiply = + final Expression add = b.add(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 0)); + final Expression multiply = b.multiply(b.fieldReference(baseTable, 1), b.fieldReference(baseTable, 1)); - Rel projection = + final Rel projection = Project.builder() .input(baseTable) .addExpressions( @@ -72,7 +73,7 @@ void projectionWithMultipleComputedExpressions() { @Test void projectionWithLiterals() { // Project with literal values - Rel projection = + final Rel projection = Project.builder() .input(baseTable) .addExpressions(b.fieldReference(baseTable, 0), b.i32(100), b.str("constant_string")) @@ -84,7 +85,7 @@ void projectionWithLiterals() { @Test void projectionWithAllFields() { // Project all fields (identity projection) - Rel projection = + final Rel projection = Project.builder() .input(baseTable) .addExpressions( @@ -100,13 +101,13 @@ void projectionWithAllFields() { @Test void nestedProjection() { // Project on top of another projection - Rel firstProjection = + final Rel firstProjection = Project.builder() .input(baseTable) .addExpressions(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 2)) .build(); - Rel secondProjection = + final Rel secondProjection = Project.builder() .input(firstProjection) .addExpressions(b.fieldReference(firstProjection, 1)) @@ -118,9 +119,10 @@ void nestedProjection() { @Test void projectionWithComparison() { // Project with comparison expression: col_a = col_d - Expression comparison = b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 3)); + final Expression comparison = + b.equal(b.fieldReference(baseTable, 0), b.fieldReference(baseTable, 3)); - Rel projection = + final Rel projection = Project.builder() .input(baseTable) .addExpressions(b.fieldReference(baseTable, 0), comparison) @@ -132,9 +134,9 @@ void projectionWithComparison() { @Test void projectionWithCast() { // Project with type cast: CAST(col_d AS BIGINT) - Expression cast = b.cast(b.fieldReference(baseTable, 3), R.I64); + final Expression cast = b.cast(b.fieldReference(baseTable, 3), R.I64); - Rel projection = Project.builder().input(baseTable).addExpressions(cast).build(); + final Rel projection = Project.builder().input(baseTable).addExpressions(cast).build(); verifyRoundTrip(projection); } @@ -142,7 +144,7 @@ void projectionWithCast() { @Test void emptyProjection() { // Project with no expressions (edge case - may produce empty output schema) - Rel projection = Project.builder().input(baseTable).build(); + final Rel projection = Project.builder().input(baseTable).build(); verifyRoundTrip(projection); } diff --git a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java index a6d6d764f..59c6969da 100644 --- a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java @@ -15,9 +15,9 @@ class ReadRelRoundtripTest extends TestBase { @Test void namedScan() { - List tableName = Stream.of("a_table").collect(Collectors.toList()); - List columnNames = Stream.of("column1", "column2").collect(Collectors.toList()); - List columnTypes = Stream.of(R.I64, R.I64).collect(Collectors.toList()); + final List tableName = Stream.of("a_table").collect(Collectors.toList()); + final List columnNames = Stream.of("column1", "column2").collect(Collectors.toList()); + final List columnTypes = Stream.of(R.I64, R.I64).collect(Collectors.toList()); NamedScan namedScan = b.namedScan(tableName, columnNames, columnTypes); namedScan = @@ -33,7 +33,7 @@ void namedScan() { @Test void emptyScan() { - io.substrait.relation.EmptyScan emptyScan = b.emptyScan(); + final io.substrait.relation.EmptyScan emptyScan = b.emptyScan(); verifyRoundTrip(emptyScan); } diff --git a/core/src/test/java/io/substrait/type/proto/SortRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/SortRelRoundtripTest.java index 9ec30822c..069c8fa53 100644 --- a/core/src/test/java/io/substrait/type/proto/SortRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/SortRelRoundtripTest.java @@ -19,13 +19,13 @@ class SortRelRoundtripTest extends TestBase { @Test void simpleSortAscending() { // Sort by id ascending, nulls first - Expression.SortField sortField = + final Expression.SortField sortField = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 0)) .direction(Expression.SortDirection.ASC_NULLS_FIRST) .build(); - Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); + final Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); verifyRoundTrip(sort); } @@ -33,13 +33,13 @@ void simpleSortAscending() { @Test void sortAscendingNullsLast() { // Sort by name ascending, nulls last - Expression.SortField sortField = + final Expression.SortField sortField = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 2)) .direction(Expression.SortDirection.ASC_NULLS_LAST) .build(); - Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); + final Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); verifyRoundTrip(sort); } @@ -47,13 +47,13 @@ void sortAscendingNullsLast() { @Test void sortDescendingNullsFirst() { // Sort by amount descending, nulls first - Expression.SortField sortField = + final Expression.SortField sortField = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 1)) .direction(Expression.SortDirection.DESC_NULLS_FIRST) .build(); - Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); + final Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); verifyRoundTrip(sort); } @@ -61,13 +61,13 @@ void sortDescendingNullsFirst() { @Test void sortDescendingNullsLast() { // Sort by timestamp descending, nulls last - Expression.SortField sortField = + final Expression.SortField sortField = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 4)) .direction(Expression.SortDirection.DESC_NULLS_LAST) .build(); - Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); + final Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); verifyRoundTrip(sort); } @@ -75,13 +75,13 @@ void sortDescendingNullsLast() { @Test void sortClustered() { // Sort with clustered direction (no specific order guarantee) - Expression.SortField sortField = + final Expression.SortField sortField = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 3)) .direction(Expression.SortDirection.CLUSTERED) .build(); - Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); + final Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); verifyRoundTrip(sort); } @@ -89,19 +89,19 @@ void sortClustered() { @Test void multipleSortFields() { // Sort by category (asc), then amount (desc) - Expression.SortField sortField1 = + final Expression.SortField sortField1 = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 3)) .direction(Expression.SortDirection.ASC_NULLS_FIRST) .build(); - Expression.SortField sortField2 = + final Expression.SortField sortField2 = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 1)) .direction(Expression.SortDirection.DESC_NULLS_LAST) .build(); - Rel sort = Sort.builder().input(baseTable).addSortFields(sortField1, sortField2).build(); + final Rel sort = Sort.builder().input(baseTable).addSortFields(sortField1, sortField2).build(); verifyRoundTrip(sort); } @@ -109,25 +109,25 @@ void multipleSortFields() { @Test void sortByThreeFields() { // Sort by category, name, and id - Expression.SortField sortField1 = + final Expression.SortField sortField1 = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 3)) .direction(Expression.SortDirection.ASC_NULLS_LAST) .build(); - Expression.SortField sortField2 = + final Expression.SortField sortField2 = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 2)) .direction(Expression.SortDirection.ASC_NULLS_LAST) .build(); - Expression.SortField sortField3 = + final Expression.SortField sortField3 = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 0)) .direction(Expression.SortDirection.ASC_NULLS_FIRST) .build(); - Rel sort = + final Rel sort = Sort.builder().input(baseTable).addSortFields(sortField1, sortField2, sortField3).build(); verifyRoundTrip(sort); @@ -136,15 +136,15 @@ void sortByThreeFields() { @Test void sortByComputedExpression() { // Sort by computed expression: amount * 2 - Expression computedExpr = b.multiply(b.fieldReference(baseTable, 1), b.fp64(2.0)); + final Expression computedExpr = b.multiply(b.fieldReference(baseTable, 1), b.fp64(2.0)); - Expression.SortField sortField = + final Expression.SortField sortField = Expression.SortField.builder() .expr(computedExpr) .direction(Expression.SortDirection.DESC_NULLS_LAST) .build(); - Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); + final Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); verifyRoundTrip(sort); } @@ -152,13 +152,13 @@ void sortByComputedExpression() { @Test void sortByStringField() { // Sort by string field directly - Expression.SortField sortField = + final Expression.SortField sortField = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 2)) .direction(Expression.SortDirection.ASC_NULLS_LAST) .build(); - Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); + final Rel sort = Sort.builder().input(baseTable).addSortFields(sortField).build(); verifyRoundTrip(sort); } @@ -166,25 +166,25 @@ void sortByStringField() { @Test void sortWithMixedNullHandling() { // Sort with different null handling for different fields - Expression.SortField sortField1 = + final Expression.SortField sortField1 = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 3)) .direction(Expression.SortDirection.ASC_NULLS_FIRST) .build(); - Expression.SortField sortField2 = + final Expression.SortField sortField2 = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 1)) .direction(Expression.SortDirection.DESC_NULLS_FIRST) .build(); - Expression.SortField sortField3 = + final Expression.SortField sortField3 = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 2)) .direction(Expression.SortDirection.ASC_NULLS_LAST) .build(); - Rel sort = + final Rel sort = Sort.builder().input(baseTable).addSortFields(sortField1, sortField2, sortField3).build(); verifyRoundTrip(sort); @@ -193,7 +193,7 @@ void sortWithMixedNullHandling() { @Test void sortAllDirections() { // Test all sort directions in single sort operation - Rel sort = + final Rel sort = Sort.builder() .input(baseTable) .addSortFields( @@ -225,21 +225,21 @@ void sortAllDirections() { @Test void nestedSort() { // Sort on top of another sort - Expression.SortField firstSort = + final Expression.SortField firstSort = Expression.SortField.builder() .expr(b.fieldReference(baseTable, 3)) .direction(Expression.SortDirection.ASC_NULLS_FIRST) .build(); - Rel firstSortRel = Sort.builder().input(baseTable).addSortFields(firstSort).build(); + final Rel firstSortRel = Sort.builder().input(baseTable).addSortFields(firstSort).build(); - Expression.SortField secondSort = + final Expression.SortField secondSort = Expression.SortField.builder() .expr(b.fieldReference(firstSortRel, 0)) .direction(Expression.SortDirection.DESC_NULLS_LAST) .build(); - Rel secondSortRel = Sort.builder().input(firstSortRel).addSortFields(secondSort).build(); + final Rel secondSortRel = Sort.builder().input(firstSortRel).addSortFields(secondSort).build(); verifyRoundTrip(secondSortRel); } diff --git a/core/src/test/java/io/substrait/type/proto/TestTypeRoundtrip.java b/core/src/test/java/io/substrait/type/proto/TestTypeRoundtrip.java index 5b271e52b..bb2309322 100644 --- a/core/src/test/java/io/substrait/type/proto/TestTypeRoundtrip.java +++ b/core/src/test/java/io/substrait/type/proto/TestTypeRoundtrip.java @@ -11,15 +11,15 @@ class TestTypeRoundtrip { - private ExtensionCollector lookup = new ExtensionCollector(); - private TypeProtoConverter typeProtoConverter = new TypeProtoConverter(lookup); + private final ExtensionCollector lookup = new ExtensionCollector(); + private final TypeProtoConverter typeProtoConverter = new TypeProtoConverter(lookup); - private ProtoTypeConverter protoTypeConverter = + private final ProtoTypeConverter protoTypeConverter = new ProtoTypeConverter(lookup, SimpleExtension.ExtensionCollection.builder().build()); @ParameterizedTest @ValueSource(booleans = {true, false}) - void roundtrip(boolean n) { + void roundtrip(final boolean n) { t(creator(n).BOOLEAN); t(creator(n).I8); t(creator(n).I16); @@ -53,12 +53,12 @@ void roundtrip(boolean n) { * * @param type */ - private void t(Type type) { - io.substrait.proto.Type converted = type.accept(typeProtoConverter); + private void t(final Type type) { + final io.substrait.proto.Type converted = type.accept(typeProtoConverter); assertEquals(type, protoTypeConverter.from(converted)); } - private TypeCreator creator(boolean nullable) { + private TypeCreator creator(final boolean nullable) { return nullable ? TypeCreator.NULLABLE : TypeCreator.REQUIRED; } } diff --git a/core/src/test/java/io/substrait/type/proto/UpdateRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/UpdateRelRoundtripTest.java index dd7c1b2fa..3523e9d56 100644 --- a/core/src/test/java/io/substrait/type/proto/UpdateRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/UpdateRelRoundtripTest.java @@ -18,11 +18,11 @@ class UpdateRelRoundtripTest extends TestBase { @Test void update() { - NamedStruct schema = + final NamedStruct schema = NamedStruct.of( Stream.of("column1", "column2").collect(Collectors.toList()), R.struct(R.I64, R.I64)); - List transformations = + final List transformations = Arrays.asList( NamedUpdate.TransformExpression.builder() .columnTarget(0) @@ -33,9 +33,9 @@ void update() { .transformation(fnAdd(2)) .build()); - Expression condition = ExpressionCreator.bool(false, true); + final Expression condition = ExpressionCreator.bool(false, true); - NamedUpdate command = + final NamedUpdate command = NamedUpdate.builder() .tableSchema(schema) .names(Stream.of("table").collect(Collectors.toList())) @@ -46,7 +46,7 @@ void update() { verifyRoundTrip(command); } - private Expression.ScalarFunctionInvocation fnAdd(int value) { + private Expression.ScalarFunctionInvocation fnAdd(final int value) { return defaultExtensionCollection.scalarFunctions().stream() .filter(s -> s.name().equalsIgnoreCase("add")) .findFirst() diff --git a/core/src/test/java/io/substrait/type/proto/WriteRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/WriteRelRoundtripTest.java index 16e05e538..da4ccd005 100644 --- a/core/src/test/java/io/substrait/type/proto/WriteRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/WriteRelRoundtripTest.java @@ -20,7 +20,7 @@ class WriteRelRoundtripTest extends TestBase { @Test void insert() { - NamedStruct schema = + final NamedStruct schema = NamedStruct.of( Stream.of("column1", "column2").collect(Collectors.toList()), R.struct(R.I64, R.I64)); @@ -37,7 +37,7 @@ void insert() { .filter(b.equal(b.fieldReference(virtTable, 0), b.fieldReference(virtTable, 1))) .build(); - NamedWrite command = + final NamedWrite command = NamedWrite.builder() .input(virtTable) .tableSchema(schema) @@ -52,16 +52,16 @@ void insert() { @Test void append() { - ProtoRelConverter protoRelConverter = + final ProtoRelConverter protoRelConverter = new StringHolderHandlingProtoRelConverter(functionCollector, defaultExtensionCollection); - StringHolder detail = new StringHolder("DETAIL"); + final StringHolder detail = new StringHolder("DETAIL"); - NamedStruct schema = + final NamedStruct schema = NamedStruct.of( Stream.of("column1", "column2").collect(Collectors.toList()), R.struct(R.I64, R.I64)); - VirtualTableScan virtTable = + final VirtualTableScan virtTable = VirtualTableScan.builder() .initialSchema(schema) .addRows( @@ -69,7 +69,7 @@ void append() { false, ExpressionCreator.i64(false, 1), ExpressionCreator.i64(false, 2))) .build(); - ExtensionWrite command = + final ExtensionWrite command = ExtensionWrite.builder() .input(virtTable) .tableSchema(schema) @@ -79,8 +79,8 @@ void append() { .outputMode(ExtensionWrite.OutputMode.NO_OUTPUT) .build(); - io.substrait.proto.Rel protoRel = relProtoConverter.toProto(command); - Rel relReturned = protoRelConverter.from(protoRel); + final io.substrait.proto.Rel protoRel = relProtoConverter.toProto(command); + final Rel relReturned = protoRelConverter.from(protoRel); assertEquals(command, relReturned); } } diff --git a/core/src/test/java/io/substrait/utils/StringHolder.java b/core/src/test/java/io/substrait/utils/StringHolder.java index 4e82e82cb..5b1377a77 100644 --- a/core/src/test/java/io/substrait/utils/StringHolder.java +++ b/core/src/test/java/io/substrait/utils/StringHolder.java @@ -34,7 +34,7 @@ public class StringHolder private final String value; - public StringHolder(String value) { + public StringHolder(final String value) { this.value = value; } @@ -43,7 +43,7 @@ public static StringHolder fromProto(final Any any) { if (PROTO_TYPE_URL.equals(any.getTypeUrl())) { return new StringHolder(any.unpack(StringValue.class).getValue()); } - } catch (InvalidProtocolBufferException e) { + } catch (final InvalidProtocolBufferException e) { throw new IllegalStateException(e); } @@ -52,7 +52,7 @@ public static StringHolder fromProto(final Any any) { } @Override - public Any toProto(RelProtoConverter relProtoConverter) { + public Any toProto(final RelProtoConverter relProtoConverter) { return com.google.protobuf.Any.pack(com.google.protobuf.StringValue.of(this.value)); } @@ -62,12 +62,12 @@ public Type.Struct deriveRecordType() { } @Override - public Type.Struct deriveRecordType(Rel input) { + public Type.Struct deriveRecordType(final Rel input) { return TypeCreator.NULLABLE.struct(); } @Override - public Type.Struct deriveRecordType(List inputs) { + public Type.Struct deriveRecordType(final List inputs) { return TypeCreator.NULLABLE.struct(); } @@ -77,10 +77,10 @@ public NamedStruct deriveSchema() { } @Override - public boolean equals(Object o) { + public boolean equals(final Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - StringHolder that = (StringHolder) o; + final StringHolder that = (StringHolder) o; return Objects.equals(value, that.value); } diff --git a/core/src/test/java/io/substrait/utils/StringHolderHandlingProtoRelConverter.java b/core/src/test/java/io/substrait/utils/StringHolderHandlingProtoRelConverter.java index 44b05c22d..380f6dfa0 100644 --- a/core/src/test/java/io/substrait/utils/StringHolderHandlingProtoRelConverter.java +++ b/core/src/test/java/io/substrait/utils/StringHolderHandlingProtoRelConverter.java @@ -14,37 +14,37 @@ */ public class StringHolderHandlingProtoRelConverter extends ProtoRelConverter { public StringHolderHandlingProtoRelConverter( - ExtensionLookup lookup, SimpleExtension.ExtensionCollection extensions) { + final ExtensionLookup lookup, final SimpleExtension.ExtensionCollection extensions) { super(lookup, extensions, new StringHolderHandlingProtoExtensionConverter()); } @Override - protected Extension.LeafRelDetail detailFromExtensionLeafRel(Any any) { + protected Extension.LeafRelDetail detailFromExtensionLeafRel(final Any any) { return StringHolder.fromProto(any); } @Override - protected Extension.SingleRelDetail detailFromExtensionSingleRel(Any any) { + protected Extension.SingleRelDetail detailFromExtensionSingleRel(final Any any) { return StringHolder.fromProto(any); } @Override - protected Extension.MultiRelDetail detailFromExtensionMultiRel(Any any) { + protected Extension.MultiRelDetail detailFromExtensionMultiRel(final Any any) { return StringHolder.fromProto(any); } @Override - protected Extension.ExtensionTableDetail detailFromExtensionTable(Any any) { + protected Extension.ExtensionTableDetail detailFromExtensionTable(final Any any) { return StringHolder.fromProto(any); } @Override - protected Extension.WriteExtensionObject detailFromWriteExtensionObject(Any any) { + protected Extension.WriteExtensionObject detailFromWriteExtensionObject(final Any any) { return StringHolder.fromProto(any); } @Override - protected Extension.DdlExtensionObject detailFromDdlExtensionObject(Any any) { + protected Extension.DdlExtensionObject detailFromDdlExtensionObject(final Any any) { return StringHolder.fromProto(any); } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/App.java b/examples/substrait-spark/src/main/java/io/substrait/examples/App.java index d2ca1edf2..5712c4ab8 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/App.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/App.java @@ -27,10 +27,10 @@ public static void main(String args[]) { if (args.length == 0) { args = new String[] {"SparkDataset"}; } - String exampleClass = args[0]; + final String exampleClass = args[0]; - Class clz = Class.forName(App.class.getPackageName() + "." + exampleClass); - Action action = (Action) clz.getDeclaredConstructor().newInstance(); + final Class clz = Class.forName(App.class.getPackageName() + "." + exampleClass); + final Action action = (Action) clz.getDeclaredConstructor().newInstance(); if (args.length == 2) { action.run(args[1]); @@ -38,7 +38,7 @@ public static void main(String args[]) { action.run(null); } - } catch (Exception e) { + } catch (final Exception e) { e.printStackTrace(); } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkConsumeSubstrait.java b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkConsumeSubstrait.java index 26c15274f..0dc5d259d 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkConsumeSubstrait.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkConsumeSubstrait.java @@ -17,22 +17,22 @@ public class SparkConsumeSubstrait implements App.Action { @Override - public void run(String arg) { + public void run(final String arg) { // Connect to a local in-process Spark instance try (SparkSession spark = SparkHelper.connectLocalSpark()) { System.out.println("Reading from " + arg); - byte[] buffer = Files.readAllBytes(Paths.get(ROOT_DIR, arg)); + final byte[] buffer = Files.readAllBytes(Paths.get(ROOT_DIR, arg)); - io.substrait.proto.Plan proto = io.substrait.proto.Plan.parseFrom(buffer); - ProtoPlanConverter protoToPlan = new ProtoPlanConverter(); - Plan plan = protoToPlan.from(proto); + final io.substrait.proto.Plan proto = io.substrait.proto.Plan.parseFrom(buffer); + final ProtoPlanConverter protoToPlan = new ProtoPlanConverter(); + final Plan plan = protoToPlan.from(proto); SubstraitStringify.explain(plan).forEach(System.out::println); - ToLogicalPlan substraitConverter = new ToLogicalPlan(spark); - LogicalPlan sparkPlan = substraitConverter.convert(plan); + final ToLogicalPlan substraitConverter = new ToLogicalPlan(spark); + final LogicalPlan sparkPlan = substraitConverter.convert(plan); System.out.println(sparkPlan); diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkDataset.java b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkDataset.java index 81de54b0b..c54528d75 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkDataset.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkDataset.java @@ -24,12 +24,12 @@ public void run(String arg) { // Connect to a local in-process Spark instance try (SparkSession spark = SparkHelper.connectLocalSpark()) { - Dataset dsVehicles; - Dataset dsTests; + final Dataset dsVehicles; + final Dataset dsTests; // load from CSV files - String vehiclesFile = Paths.get(ROOT_DIR, VEHICLES_CSV).toString(); - String testsFile = Paths.get(ROOT_DIR, TESTS_CSV).toString(); + final String vehiclesFile = Paths.get(ROOT_DIR, VEHICLES_CSV).toString(); + final String testsFile = Paths.get(ROOT_DIR, TESTS_CSV).toString(); System.out.println("Reading " + vehiclesFile); System.out.println("Reading " + testsFile); @@ -51,7 +51,7 @@ public void run(String arg) { joinedDs = joinedDs.orderBy(joinedDs.col("count")); joinedDs.show(); - LogicalPlan plan = joinedDs.queryExecution().optimizedPlan(); + final LogicalPlan plan = joinedDs.queryExecution().optimizedPlan(); System.out.println(plan); createSubstrait(plan); @@ -67,14 +67,14 @@ public void run(String arg) { * * @param enginePlan logical plan */ - public void createSubstrait(LogicalPlan enginePlan) { - ToSubstraitRel toSubstrait = new ToSubstraitRel(); - io.substrait.plan.Plan plan = toSubstrait.convert(enginePlan); + public void createSubstrait(final LogicalPlan enginePlan) { + final ToSubstraitRel toSubstrait = new ToSubstraitRel(); + final io.substrait.plan.Plan plan = toSubstrait.convert(enginePlan); SubstraitStringify.explain(plan).forEach(System.out::println); - PlanProtoConverter planToProto = new PlanProtoConverter(); - byte[] buffer = planToProto.toProto(plan).toByteArray(); + final PlanProtoConverter planToProto = new PlanProtoConverter(); + final byte[] buffer = planToProto.toProto(plan).toByteArray(); try { Files.write(Paths.get(ROOT_DIR, "spark_dataset_substrait.plan"), buffer); System.out.println("File written to " + Paths.get(ROOT_DIR, "spark_sql_substrait.plan")); diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkHelper.java b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkHelper.java index 270b760ca..3fb32a891 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkHelper.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkHelper.java @@ -29,7 +29,7 @@ private SparkHelper() {} */ public static SparkSession connectLocalSpark() { - SparkSession spark = SparkSession.builder().enableHiveSupport().getOrCreate(); + final SparkSession spark = SparkSession.builder().enableHiveSupport().getOrCreate(); spark.sparkContext().setLogLevel("ERROR"); diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkSQL.java b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkSQL.java index ddb544f00..b4b80f210 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/SparkSQL.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/SparkSQL.java @@ -28,8 +28,8 @@ public void run(String arg) { spark.catalog().listDatabases().show(); // load from CSV files - String vehiclesFile = Paths.get(ROOT_DIR, VEHICLES_CSV).toString(); - String testsFile = Paths.get(ROOT_DIR, TESTS_CSV).toString(); + final String vehiclesFile = Paths.get(ROOT_DIR, VEHICLES_CSV).toString(); + final String testsFile = Paths.get(ROOT_DIR, TESTS_CSV).toString(); System.out.println("Reading " + vehiclesFile); System.out.println("Reading " + testsFile); @@ -47,7 +47,7 @@ public void run(String arg) { .csv(testsFile) .createOrReplaceTempView(TESTS_TABLE); - String sqlQuery = + final String sqlQuery = "SELECT vehicles.colour, count(*) as colourcount" + " FROM vehicles" + " INNER JOIN tests ON vehicles.vehicle_id=tests.vehicle_id" @@ -55,13 +55,13 @@ public void run(String arg) { + " GROUP BY vehicles.colour" + " ORDER BY count(*)"; - Dataset result = spark.sql(sqlQuery); + final Dataset result = spark.sql(sqlQuery); result.show(); - LogicalPlan logical = result.logicalPlan(); + final LogicalPlan logical = result.logicalPlan(); System.out.println(logical); - LogicalPlan optimised = result.queryExecution().optimizedPlan(); + final LogicalPlan optimised = result.queryExecution().optimizedPlan(); System.out.println(optimised); createSubstrait(optimised); @@ -76,14 +76,14 @@ public void run(String arg) { * * @param enginePlan Spark Local PLan */ - public void createSubstrait(LogicalPlan enginePlan) { - ToSubstraitRel toSubstrait = new ToSubstraitRel(); - io.substrait.plan.Plan plan = toSubstrait.convert(enginePlan); + public void createSubstrait(final LogicalPlan enginePlan) { + final ToSubstraitRel toSubstrait = new ToSubstraitRel(); + final io.substrait.plan.Plan plan = toSubstrait.convert(enginePlan); SubstraitStringify.explain(plan).forEach(System.out::println); - PlanProtoConverter planToProto = new PlanProtoConverter(); - byte[] buffer = planToProto.toProto(plan).toByteArray(); + final PlanProtoConverter planToProto = new PlanProtoConverter(); + final byte[] buffer = planToProto.toProto(plan).toByteArray(); try { Files.write(Paths.get(ROOT_DIR, "spark_sql_substrait.plan"), buffer); System.out.println("File written to " + Paths.get(ROOT_DIR, "spark_sql_substrait.plan")); diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index 71de9a7d5..ee296905c 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -51,171 +51,191 @@ public class ExpressionStringify extends ParentStringify implements ExpressionVisitor { - public ExpressionStringify(int indent) { + public ExpressionStringify(final int indent) { super(indent); } @Override - public String visit(NullLiteral expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final NullLiteral expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(BoolLiteral expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final BoolLiteral expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(I8Literal expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final I8Literal expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(I16Literal expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final I16Literal expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(I32Literal expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final I32Literal expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(I64Literal expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final I64Literal expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(FP32Literal expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final FP32Literal expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(FP64Literal expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final FP64Literal expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(StrLiteral expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final StrLiteral expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(BinaryLiteral expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final BinaryLiteral expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(TimeLiteral expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final TimeLiteral expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(DateLiteral expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final DateLiteral expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(TimestampLiteral expr, EmptyVisitationContext context) + public String visit(final TimestampLiteral expr, final EmptyVisitationContext context) throws RuntimeException { return ""; } @Override - public String visit(TimestampTZLiteral expr, EmptyVisitationContext context) + public String visit(final TimestampTZLiteral expr, final EmptyVisitationContext context) throws RuntimeException { return ""; } @Override - public String visit(IntervalYearLiteral expr, EmptyVisitationContext context) + public String visit(final IntervalYearLiteral expr, final EmptyVisitationContext context) throws RuntimeException { return ""; } @Override - public String visit(IntervalDayLiteral expr, EmptyVisitationContext context) + public String visit(final IntervalDayLiteral expr, final EmptyVisitationContext context) throws RuntimeException { return ""; } @Override - public String visit(UUIDLiteral expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final UUIDLiteral expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(FixedCharLiteral expr, EmptyVisitationContext context) + public String visit(final FixedCharLiteral expr, final EmptyVisitationContext context) throws RuntimeException { return ""; } @Override - public String visit(VarCharLiteral expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final VarCharLiteral expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(FixedBinaryLiteral expr, EmptyVisitationContext context) + public String visit(final FixedBinaryLiteral expr, final EmptyVisitationContext context) throws RuntimeException { return ""; } @Override - public String visit(DecimalLiteral expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final DecimalLiteral expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(MapLiteral expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final MapLiteral expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(ListLiteral expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final ListLiteral expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(EmptyListLiteral expr, EmptyVisitationContext context) + public String visit(final EmptyListLiteral expr, final EmptyVisitationContext context) throws RuntimeException { return ""; } @Override - public String visit(StructLiteral expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final StructLiteral expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(UserDefinedLiteral expr, EmptyVisitationContext context) + public String visit(final UserDefinedLiteral expr, final EmptyVisitationContext context) throws RuntimeException { return ""; } @Override - public String visit(Switch expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final Switch expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(IfThen expr, EmptyVisitationContext context) throws RuntimeException { + public String visit(final IfThen expr, final EmptyVisitationContext context) + throws RuntimeException { return ""; } @Override - public String visit(ScalarFunctionInvocation expr, EmptyVisitationContext context) + public String visit(final ScalarFunctionInvocation expr, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = new StringBuilder(""); + final StringBuilder sb = new StringBuilder(""); sb.append(expr.declaration()); // sb.append(" ("); - List args = expr.arguments(); + final List args = expr.arguments(); for (int i = 0; i < args.size(); i++) { - FunctionArg arg = args.get(i); + final FunctionArg arg = args.get(i); sb.append(getContinuationIndentString()); sb.append("arg" + i + " = "); - FunctionArgStringify funcArgVisitor = new FunctionArgStringify(indent); + final FunctionArgStringify funcArgVisitor = new FunctionArgStringify(indent); sb.append(arg.accept(expr.declaration(), i, funcArgVisitor, context)); sb.append(" "); @@ -224,16 +244,17 @@ public String visit(ScalarFunctionInvocation expr, EmptyVisitationContext contex } @Override - public String visit(WindowFunctionInvocation expr, EmptyVisitationContext context) + public String visit(final WindowFunctionInvocation expr, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = new StringBuilder("WindowFunctionInvocation#"); + final StringBuilder sb = new StringBuilder("WindowFunctionInvocation#"); return sb.toString(); } @Override - public String visit(Cast expr, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = new StringBuilder(""; } @Override - public String visit(PrecisionTimestampTZLiteral expr, EmptyVisitationContext context) + public String visit(final PrecisionTimestampTZLiteral expr, final EmptyVisitationContext context) throws RuntimeException { return ""; } @Override - public String visit(IntervalCompoundLiteral expr, EmptyVisitationContext context) + public String visit(final IntervalCompoundLiteral expr, final EmptyVisitationContext context) throws RuntimeException { return ""; } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/FunctionArgStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/FunctionArgStringify.java index 372b7e79f..98bb7ebe4 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/FunctionArgStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/FunctionArgStringify.java @@ -11,24 +11,30 @@ public class FunctionArgStringify extends ParentStringify implements FuncArgVisitor { - public FunctionArgStringify(int indent) { + public FunctionArgStringify(final int indent) { super(indent); } @Override - public String visitExpr(Function fnDef, int argIdx, Expression e, EmptyVisitationContext context) + public String visitExpr( + final Function fnDef, + final int argIdx, + final Expression e, + final EmptyVisitationContext context) throws RuntimeException { return e.accept(new ExpressionStringify(indent + 1), context); } @Override - public String visitType(Function fnDef, int argIdx, Type t, EmptyVisitationContext context) + public String visitType( + final Function fnDef, final int argIdx, final Type t, final EmptyVisitationContext context) throws RuntimeException { return t.accept(new TypeStringify(indent)); } @Override - public String visitEnumArg(Function fnDef, int argIdx, EnumArg e, EmptyVisitationContext context) + public String visitEnumArg( + final Function fnDef, final int argIdx, final EnumArg e, final EmptyVisitationContext context) throws RuntimeException { return e.toString(); } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ParentStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ParentStringify.java index 62535bedc..c74666de4 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ParentStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ParentStringify.java @@ -15,13 +15,13 @@ public class ParentStringify { * * @param indent number of indentes */ - public ParentStringify(int indent) { + public ParentStringify(final int indent) { this.indent = indent; } StringBuilder getIndent() { - StringBuilder sb = new StringBuilder(); + final StringBuilder sb = new StringBuilder(); if (indent != 0) { sb.append("\n"); } @@ -33,7 +33,7 @@ StringBuilder getIndent() { StringBuilder getIndentString() { - StringBuilder sb = new StringBuilder(); + final StringBuilder sb = new StringBuilder(); sb.append(indentChar.repeat(this.indent * this.indentSize)); sb.append("+- "); return sb; @@ -41,7 +41,7 @@ StringBuilder getIndentString() { StringBuilder getContinuationIndentString() { - StringBuilder sb = new StringBuilder(); + final StringBuilder sb = new StringBuilder(); if (indent != 0) { sb.append("\n"); } @@ -50,7 +50,7 @@ StringBuilder getContinuationIndentString() { return sb; } - protected String getOutdent(StringBuilder sb) { + protected String getOutdent(final StringBuilder sb) { indent--; return (sb).toString(); } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java index a94d47fd4..5bfcd5548 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java @@ -65,7 +65,7 @@ public class SubstraitStringify extends ParentStringify implements RelVisitor { - private boolean showRemap = false; + private final boolean showRemap = false; public SubstraitStringify() { super(0); @@ -77,14 +77,14 @@ public SubstraitStringify() { * @param plan Subsrait plan * @return List of strings; typically these would then be logged or sent to stdout */ - public static List explain(io.substrait.plan.Plan plan) { - ArrayList explanations = new ArrayList(); + public static List explain(final io.substrait.plan.Plan plan) { + final ArrayList explanations = new ArrayList(); explanations.add(""); plan.getRoots() .forEach( root -> { - Rel rel = root.getInput(); + final Rel rel = root.getInput(); explanations.add("Root:: " + rel.getClass().getSimpleName() + " " + root.getNames()); explanations.addAll(explain(rel)); @@ -99,26 +99,26 @@ public static List explain(io.substrait.plan.Plan plan) { * @param plan Subsrait relation * @return List of strings; typically these would then be logged or sent to stdout */ - public static List explain(io.substrait.relation.Rel rel) { - SubstraitStringify s = new SubstraitStringify(); + public static List explain(final io.substrait.relation.Rel rel) { + final SubstraitStringify s = new SubstraitStringify(); - List explanation = new ArrayList(); + final List explanation = new ArrayList(); explanation.add(""); explanation.addAll(Arrays.asList(rel.accept(s, EmptyVisitationContext.INSTANCE).split("\n"))); return explanation; } - private List fieldList(List fields) { + private List fieldList(final List fields) { return fields.stream().map(t -> t.accept(new TypeStringify(0))).collect(Collectors.toList()); } - private String getRemap(Rel rel) { + private String getRemap(final Rel rel) { if (!showRemap) { return ""; } - int fieldCount = rel.getRecordType().fields().size(); - Optional remap = rel.getRemap(); - List recordType = fieldList(rel.getRecordType().fields()); + final int fieldCount = rel.getRecordType().fields().size(); + final Optional remap = rel.getRemap(); + final List recordType = fieldList(rel.getRecordType().fields()); if (remap.isPresent()) { return "/Remapping fields (" @@ -134,8 +134,9 @@ private String getRemap(Rel rel) { } @Override - public String visit(Aggregate aggregate, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("Aggregate:: ").append(getRemap(aggregate)); + public String visit(final Aggregate aggregate, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = getIndent().append("Aggregate:: ").append(getRemap(aggregate)); aggregate .getGroupings() .forEach( @@ -158,22 +159,25 @@ public String visit(Aggregate aggregate, EmptyVisitationContext context) throws } @Override - public String visit(EmptyScan emptyScan, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = new StringBuilder("EmptyScan:: ").append(getRemap(emptyScan)); + public String visit(final EmptyScan emptyScan, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = new StringBuilder("EmptyScan:: ").append(getRemap(emptyScan)); // sb.append(emptyScan.accept(this)); return getOutdent(sb); } @Override - public String visit(Fetch fetch, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = new StringBuilder("Fetch:: "); + public String visit(final Fetch fetch, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = new StringBuilder("Fetch:: "); // sb.append(fetch.accept(this)); return getOutdent(sb); } @Override - public String visit(Filter filter, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("Filter:: ").append(getRemap(filter)); + public String visit(final Filter filter, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = getIndent().append("Filter:: ").append(getRemap(filter)); // .append("{ "); sb.append( filter.getCondition().accept(new ExpressionStringify(indent), context)) /* .append(")") */; @@ -188,9 +192,10 @@ public String visit(Filter filter, EmptyVisitationContext context) throws Runtim } @Override - public String visit(Join join, EmptyVisitationContext context) throws RuntimeException { + public String visit(final Join join, final EmptyVisitationContext context) + throws RuntimeException { - StringBuilder sb = + final StringBuilder sb = getIndent().append("Join:: ").append(join.getJoinType()).append(" ").append(getRemap(join)); if (join.getCondition().isPresent()) { @@ -204,15 +209,16 @@ public String visit(Join join, EmptyVisitationContext context) throws RuntimeExc } @Override - public String visit(Set set, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("Set:: "); + public String visit(final Set set, final EmptyVisitationContext context) throws RuntimeException { + final StringBuilder sb = getIndent().append("Set:: "); return getOutdent(sb); } @Override - public String visit(NamedScan namedScan, EmptyVisitationContext context) throws RuntimeException { + public String visit(final NamedScan namedScan, final EmptyVisitationContext context) + throws RuntimeException { - StringBuilder sb = getIndent().append("NamedScan:: ").append(getRemap(namedScan)); + final StringBuilder sb = getIndent().append("NamedScan:: ").append(getRemap(namedScan)); namedScan .getInputs() .forEach( @@ -226,11 +232,11 @@ public String visit(NamedScan namedScan, EmptyVisitationContext context) throws return getOutdent(sb); } - private String namedStruct(NamedStruct struct) { - StringBuilder sb = new StringBuilder(); + private String namedStruct(final NamedStruct struct) { + final StringBuilder sb = new StringBuilder(); - List names = struct.names(); - List types = fieldList(struct.struct().fields()); + final List names = struct.names(); + final List types = fieldList(struct.struct().fields()); for (int x = 0; x < names.size(); x++) { if (x != 0) { @@ -243,11 +249,11 @@ private String namedStruct(NamedStruct struct) { } @Override - public String visit(LocalFiles localFiles, EmptyVisitationContext context) + public String visit(final LocalFiles localFiles, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("LocalFiles:: "); + final StringBuilder sb = getIndent().append("LocalFiles:: "); - for (FileOrFiles i : localFiles.getItems()) { + for (final FileOrFiles i : localFiles.getItems()) { sb.append(getContinuationIndentString()); String fileFormat = ""; if (i.getFileFormat().isPresent()) { @@ -264,12 +270,13 @@ public String visit(LocalFiles localFiles, EmptyVisitationContext context) } @Override - public String visit(Project project, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("Project:: ").append(getRemap(project)); + public String visit(final Project project, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = getIndent().append("Project:: ").append(getRemap(project)); sb.append(fieldList(project.deriveRecordType().fields())); - List inputs = project.getInputs(); + final List inputs = project.getInputs(); inputs.forEach( i -> { sb.append(i.accept(this, context)); @@ -278,15 +285,16 @@ public String visit(Project project, EmptyVisitationContext context) throws Runt } @Override - public String visit(Sort sort, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("Sort:: ").append(getRemap(sort)); + public String visit(final Sort sort, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = getIndent().append("Sort:: ").append(getRemap(sort)); sort.getSortFields() .forEach( sf -> { - ExpressionStringify expr = new ExpressionStringify(indent); + final ExpressionStringify expr = new ExpressionStringify(indent); sb.append(sf.expr().accept(expr, context)).append(" ").append(sf.direction()); }); - List inputs = sort.getInputs(); + final List inputs = sort.getInputs(); inputs.forEach( i -> { sb.append(i.accept(this, context)); @@ -295,142 +303,151 @@ public String visit(Sort sort, EmptyVisitationContext context) throws RuntimeExc } @Override - public String visit(Cross cross, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("Cross:: "); + public String visit(final Cross cross, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = getIndent().append("Cross:: "); return getOutdent(sb); } @Override - public String visit(VirtualTableScan virtualTableScan, EmptyVisitationContext context) + public String visit(final VirtualTableScan virtualTableScan, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("VirtualTableScan:: "); + final StringBuilder sb = getIndent().append("VirtualTableScan:: "); return getOutdent(sb); } @Override - public String visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) + public String visit(final ExtensionLeaf extensionLeaf, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("extensionLeaf:: "); + final StringBuilder sb = getIndent().append("extensionLeaf:: "); return getOutdent(sb); } @Override - public String visit(ExtensionSingle extensionSingle, EmptyVisitationContext context) + public String visit(final ExtensionSingle extensionSingle, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("extensionSingle:: "); + final StringBuilder sb = getIndent().append("extensionSingle:: "); return getOutdent(sb); } @Override - public String visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) + public String visit(final ExtensionMulti extensionMulti, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("extensionMulti:: "); + final StringBuilder sb = getIndent().append("extensionMulti:: "); return getOutdent(sb); } @Override - public String visit(ExtensionTable extensionTable, EmptyVisitationContext context) + public String visit(final ExtensionTable extensionTable, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("extensionTable:: "); + final StringBuilder sb = getIndent().append("extensionTable:: "); return getOutdent(sb); } @Override - public String visit(HashJoin hashJoin, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("hashJoin:: "); + public String visit(final HashJoin hashJoin, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = getIndent().append("hashJoin:: "); return getOutdent(sb); } @Override - public String visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("mergeJoin:: "); + public String visit(final MergeJoin mergeJoin, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = getIndent().append("mergeJoin:: "); return getOutdent(sb); } @Override - public String visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) + public String visit(final NestedLoopJoin nestedLoopJoin, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("nestedLoopJoin:: "); + final StringBuilder sb = getIndent().append("nestedLoopJoin:: "); return getOutdent(sb); } @Override public String visit( - ConsistentPartitionWindow consistentPartitionWindow, EmptyVisitationContext context) + final ConsistentPartitionWindow consistentPartitionWindow, + final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("consistentPartitionWindow:: "); + final StringBuilder sb = getIndent().append("consistentPartitionWindow:: "); return getOutdent(sb); } @Override - public String visit(Expand expand, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("expand:: "); + public String visit(final Expand expand, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = getIndent().append("expand:: "); return getOutdent(sb); } @Override - public String visit(NamedWrite write, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("namedWrite:: "); + public String visit(final NamedWrite write, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = getIndent().append("namedWrite:: "); return getOutdent(sb); } @Override - public String visit(ExtensionWrite write, EmptyVisitationContext context) + public String visit(final ExtensionWrite write, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("extensionWrite:: "); + final StringBuilder sb = getIndent().append("extensionWrite:: "); return getOutdent(sb); } @Override - public String visit(NamedDdl ddl, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("namedDdl:: "); + public String visit(final NamedDdl ddl, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = getIndent().append("namedDdl:: "); return getOutdent(sb); } @Override - public String visit(ExtensionDdl ddl, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("extensionDdl:: "); + public String visit(final ExtensionDdl ddl, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = getIndent().append("extensionDdl:: "); return getOutdent(sb); } @Override - public String visit(NamedUpdate update, EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("namedUpdate:: "); + public String visit(final NamedUpdate update, final EmptyVisitationContext context) + throws RuntimeException { + final StringBuilder sb = getIndent().append("namedUpdate:: "); return getOutdent(sb); } @Override - public String visit(ScatterExchange exchange, EmptyVisitationContext context) + public String visit(final ScatterExchange exchange, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("scatterExchange:: "); + final StringBuilder sb = getIndent().append("scatterExchange:: "); return getOutdent(sb); } @Override - public String visit(SingleBucketExchange exchange, EmptyVisitationContext context) + public String visit(final SingleBucketExchange exchange, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("singleBucketExchange:: "); + final StringBuilder sb = getIndent().append("singleBucketExchange:: "); return getOutdent(sb); } @Override - public String visit(MultiBucketExchange exchange, EmptyVisitationContext context) + public String visit(final MultiBucketExchange exchange, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("multiBucketExchange:: "); + final StringBuilder sb = getIndent().append("multiBucketExchange:: "); return getOutdent(sb); } @Override - public String visit(RoundRobinExchange exchange, EmptyVisitationContext context) + public String visit(final RoundRobinExchange exchange, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("roundRobinExchange:: "); + final StringBuilder sb = getIndent().append("roundRobinExchange:: "); return getOutdent(sb); } @Override - public String visit(BroadcastExchange exchange, EmptyVisitationContext context) + public String visit(final BroadcastExchange exchange, final EmptyVisitationContext context) throws RuntimeException { - StringBuilder sb = getIndent().append("broadcastExchange:: "); + final StringBuilder sb = getIndent().append("broadcastExchange:: "); return getOutdent(sb); } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java index 0e13c3e2e..a86877ec4 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java @@ -33,125 +33,125 @@ public class TypeStringify extends ParentStringify implements TypeVisitor { - protected TypeStringify(int indent) { + protected TypeStringify(final int indent) { super(indent); } @Override - public String visit(I64 type) throws RuntimeException { + public String visit(final I64 type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(Bool type) throws RuntimeException { + public String visit(final Bool type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(I8 type) throws RuntimeException { + public String visit(final I8 type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(I16 type) throws RuntimeException { + public String visit(final I16 type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(I32 type) throws RuntimeException { + public String visit(final I32 type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(FP32 type) throws RuntimeException { + public String visit(final FP32 type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(FP64 type) throws RuntimeException { + public String visit(final FP64 type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(Str type) throws RuntimeException { + public String visit(final Str type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(Binary type) throws RuntimeException { + public String visit(final Binary type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(Date type) throws RuntimeException { + public String visit(final Date type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(Time type) throws RuntimeException { + public String visit(final Time type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override @Deprecated - public String visit(TimestampTZ type) throws RuntimeException { + public String visit(final TimestampTZ type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override @Deprecated - public String visit(Timestamp type) throws RuntimeException { + public String visit(final Timestamp type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(Type.PrecisionTimestamp type) throws RuntimeException { + public String visit(final Type.PrecisionTimestamp type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(Type.PrecisionTimestampTZ type) throws RuntimeException { + public String visit(final Type.PrecisionTimestampTZ type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(IntervalYear type) throws RuntimeException { + public String visit(final IntervalYear type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(IntervalDay type) throws RuntimeException { + public String visit(final IntervalDay type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(UUID type) throws RuntimeException { + public String visit(final UUID type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(FixedChar type) throws RuntimeException { + public String visit(final FixedChar type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(VarChar type) throws RuntimeException { + public String visit(final VarChar type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(FixedBinary type) throws RuntimeException { + public String visit(final FixedBinary type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(Decimal type) throws RuntimeException { + public String visit(final Decimal type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(Struct type) throws RuntimeException { - StringBuffer sb = new StringBuffer(type.getClass().getSimpleName()); + public String visit(final Struct type) throws RuntimeException { + final StringBuffer sb = new StringBuffer(type.getClass().getSimpleName()); type.fields() .forEach( f -> { @@ -161,27 +161,27 @@ public String visit(Struct type) throws RuntimeException { } @Override - public String visit(ListType type) throws RuntimeException { + public String visit(final ListType type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(Map type) throws RuntimeException { + public String visit(final Map type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(UserDefined type) throws RuntimeException { + public String visit(final UserDefined type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(PrecisionTime type) throws RuntimeException { + public String visit(final PrecisionTime type) throws RuntimeException { return type.getClass().getSimpleName(); } @Override - public String visit(IntervalCompound type) throws RuntimeException { + public String visit(final IntervalCompound type) throws RuntimeException { return type.getClass().getSimpleName(); } } diff --git a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java index 35d5c6ed2..ffe174960 100644 --- a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java +++ b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java @@ -61,10 +61,10 @@ enum OutputFormat { description = "Calcite's casing policy for unquoted identifiers: ${COMPLETION-CANDIDATES}") private Casing unquotedCasing = Casing.TO_UPPER; - public static void main(String... args) { - CommandLine commandLine = new CommandLine(new IsthmusEntryPoint()); + public static void main(final String... args) { + final CommandLine commandLine = new CommandLine(new IsthmusEntryPoint()); commandLine.setCaseInsensitiveEnumValuesAllowed(true); - CommandLine.ParseResult parseResult = commandLine.parseArgs(args); + final CommandLine.ParseResult parseResult = commandLine.parseArgs(args); if (parseResult.originalArgs().isEmpty()) { // If no arguments print usage help commandLine.usage(System.out); System.exit(0); @@ -77,31 +77,32 @@ public static void main(String... args) { commandLine.printVersionHelp(System.out); System.exit(0); } - int exitCode = commandLine.execute(args); + final int exitCode = commandLine.execute(args); System.exit(exitCode); } @Override public Integer call() throws Exception { - FeatureBoard featureBoard = buildFeatureBoard(); + final FeatureBoard featureBoard = buildFeatureBoard(); // Isthmus image is parsing SQL Expression if that argument is defined if (sqlExpressions != null) { - SqlExpressionToSubstrait converter = + final SqlExpressionToSubstrait converter = new SqlExpressionToSubstrait(featureBoard, DefaultExtensionCatalog.DEFAULT_COLLECTION); - ExtendedExpression extendedExpression = converter.convert(sqlExpressions, createStatements); + final ExtendedExpression extendedExpression = + converter.convert(sqlExpressions, createStatements); printMessage(extendedExpression); } else { // by default Isthmus image are parsing SQL Query - SqlToSubstrait converter = new SqlToSubstrait(featureBoard); - Prepare.CatalogReader catalog = + final SqlToSubstrait converter = new SqlToSubstrait(featureBoard); + final Prepare.CatalogReader catalog = SubstraitCreateStatementParser.processCreateStatementsToCatalog( createStatements.toArray(String[]::new)); - Plan plan = new PlanProtoConverter().toProto(converter.convert(sql, catalog)); + final Plan plan = new PlanProtoConverter().toProto(converter.convert(sql, catalog)); printMessage(plan); } return 0; } - private void printMessage(Message message) throws IOException { + private void printMessage(final Message message) throws IOException { if (outputFormat == OutputFormat.PROTOJSON) { System.out.println(JsonFormat.printer().includingDefaultValueFields().print(message)); } else if (outputFormat == OutputFormat.PROTOTEXT) { diff --git a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/RegisterAtRuntime.java b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/RegisterAtRuntime.java index 266a36ff0..5f6826c05 100644 --- a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/RegisterAtRuntime.java +++ b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/RegisterAtRuntime.java @@ -121,7 +121,7 @@ public void beforeAnalysis(BeforeAnalysisAccess access) { } } - private static void register(Class c) { + private static void register(final Class c) { RuntimeReflection.register(c); RuntimeReflection.register(c.getDeclaredConstructors()); RuntimeReflection.register(c.getDeclaredFields()); @@ -134,7 +134,7 @@ private static void register(Class c) { private static final class PackageScanner implements AutoCloseable { private final ScanResult scan; - PackageScanner(String... packageNames) { + PackageScanner(final String... packageNames) { scan = new ClassGraph() .enableAllInfo() @@ -144,16 +144,16 @@ private static final class PackageScanner implements AutoCloseable { .scan(); } - void registerByAnnotation(Class annotation) { + void registerByAnnotation(final Class annotation) { scan.getClassesWithAnnotation(annotation).loadClasses().forEach(this::registerByParent); } - void registerByParent(Class c) { + void registerByParent(final Class c) { register(c); getSubTypes(c).loadClasses().forEach(RegisterAtRuntime::register); } - private ClassInfoList getSubTypes(Class c) { + private ClassInfoList getSubTypes(final Class c) { if (c.isInterface()) { return scan.getClassesImplementing(c); } diff --git a/isthmus-cli/src/test/java/io/substrait/isthmus/cli/IsthmusEntryPointTest.java b/isthmus-cli/src/test/java/io/substrait/isthmus/cli/IsthmusEntryPointTest.java index 1a4fe42c8..e2f3bdc08 100644 --- a/isthmus-cli/src/test/java/io/substrait/isthmus/cli/IsthmusEntryPointTest.java +++ b/isthmus-cli/src/test/java/io/substrait/isthmus/cli/IsthmusEntryPointTest.java @@ -9,17 +9,17 @@ class IsthmusEntryPointTest { @Test void canProcessQuery() { - IsthmusEntryPoint isthmusEntryPoint = new IsthmusEntryPoint(); - CommandLine cli = new CommandLine(isthmusEntryPoint); - int statusCode = cli.execute("SELECT 1;"); + final IsthmusEntryPoint isthmusEntryPoint = new IsthmusEntryPoint(); + final CommandLine cli = new CommandLine(isthmusEntryPoint); + final int statusCode = cli.execute("SELECT 1;"); assertEquals(0, statusCode); } @Test void canProcessQueryWithCreates() { - IsthmusEntryPoint isthmusEntryPoint = new IsthmusEntryPoint(); - CommandLine cli = new CommandLine(isthmusEntryPoint); - int statusCode = cli.execute("SELECT * FROM foo", "--create", "CREATE TABLE foo(id INT)"); + final IsthmusEntryPoint isthmusEntryPoint = new IsthmusEntryPoint(); + final CommandLine cli = new CommandLine(isthmusEntryPoint); + final int statusCode = cli.execute("SELECT * FROM foo", "--create", "CREATE TABLE foo(id INT)"); assertEquals(0, statusCode); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java index 6cba80781..b2df1f05c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java +++ b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java @@ -37,9 +37,9 @@ public class AggregateFunctions { * @return an optional containing the Substrait equivalent of the given {@code aggFunction} if * conversion was needed, empty otherwise. */ - public static Optional toSubstraitAggVariant(SqlAggFunction aggFunction) { + public static Optional toSubstraitAggVariant(final SqlAggFunction aggFunction) { if (aggFunction instanceof SqlMinMaxAggFunction) { - SqlMinMaxAggFunction fun = (SqlMinMaxAggFunction) aggFunction; + final SqlMinMaxAggFunction fun = (SqlMinMaxAggFunction) aggFunction; return Optional.of( fun.getKind() == SqlKind.MIN ? AggregateFunctions.MIN : AggregateFunctions.MAX); } else if (aggFunction instanceof SqlAvgAggFunction) { @@ -55,12 +55,12 @@ public static Optional toSubstraitAggVariant(SqlAggFunction aggF /** Extension of {@link SqlMinMaxAggFunction} that ALWAYS infers a nullable return type */ private static class SubstraitSqlMinMaxAggFunction extends SqlMinMaxAggFunction { - public SubstraitSqlMinMaxAggFunction(SqlKind kind) { + public SubstraitSqlMinMaxAggFunction(final SqlKind kind) { super(kind); } @Override - public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + public RelDataType inferReturnType(final SqlOperatorBinding opBinding) { return ReturnTypes.ARG0_FORCE_NULLABLE.inferReturnType(opBinding); } } @@ -74,19 +74,19 @@ public SubstraitSumAggFunction() { } @Override - public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + public RelDataType inferReturnType(final SqlOperatorBinding opBinding) { return ReturnTypes.ARG0_FORCE_NULLABLE.inferReturnType(opBinding); } } /** Extension of {@link SqlAvgAggFunction} that ALWAYS infers a nullable return type */ private static class SubstraitAvgAggFunction extends SqlAvgAggFunction { - public SubstraitAvgAggFunction(SqlKind kind) { + public SubstraitAvgAggFunction(final SqlKind kind) { super(kind); } @Override - public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + public RelDataType inferReturnType(final SqlOperatorBinding opBinding) { return ReturnTypes.ARG0_FORCE_NULLABLE.inferReturnType(opBinding); } } @@ -109,7 +109,7 @@ public String getName() { } @Override - public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + public RelDataType inferReturnType(final SqlOperatorBinding opBinding) { return ReturnTypes.BIGINT.inferReturnType(opBinding); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java b/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java index eeb645175..2208c4919 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java +++ b/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java @@ -29,11 +29,11 @@ public OuterReferenceResolver() { fieldAccessDepthMap = new IdentityHashMap<>(); } - public int getStepsOut(RexFieldAccess fieldAccess) { + public int getStepsOut(final RexFieldAccess fieldAccess) { return fieldAccessDepthMap.get(fieldAccess); } - public RelNode apply(RelNode r) { + public RelNode apply(final RelNode r) { return reverseAccept(r); } @@ -42,8 +42,8 @@ public Map getFieldAccessDepthMap() { } @Override - public RelNode visit(Filter filter) throws RuntimeException { - for (CorrelationId id : filter.getVariablesSet()) { + public RelNode visit(final Filter filter) throws RuntimeException { + for (final CorrelationId id : filter.getVariablesSet()) { nestedDepth.putIfAbsent(id, 0); } filter.getCondition().accept(rexVisitor); @@ -51,8 +51,8 @@ public RelNode visit(Filter filter) throws RuntimeException { } @Override - public RelNode visit(Correlate correlate) throws RuntimeException { - for (CorrelationId id : correlate.getVariablesSet()) { + public RelNode visit(final Correlate correlate) throws RuntimeException { + for (final CorrelationId id : correlate.getVariablesSet()) { nestedDepth.putIfAbsent(id, 0); } @@ -71,20 +71,20 @@ public RelNode visit(Correlate correlate) throws RuntimeException { } @Override - public RelNode visitOther(RelNode other) throws RuntimeException { - for (RelNode child : other.getInputs()) { + public RelNode visitOther(final RelNode other) throws RuntimeException { + for (final RelNode child : other.getInputs()) { apply(child); } return other; } @Override - public RelNode visit(Project project) throws RuntimeException { - for (CorrelationId id : project.getVariablesSet()) { + public RelNode visit(final Project project) throws RuntimeException { + for (final CorrelationId id : project.getVariablesSet()) { nestedDepth.putIfAbsent(id, 0); } - for (RexSubQuery subQuery : SubQueryCollector.collect(project)) { + for (final RexSubQuery subQuery : SubQueryCollector.collect(project)) { subQuery.accept(rexVisitor); } @@ -94,12 +94,12 @@ public RelNode visit(Project project) throws RuntimeException { private static class RexVisitor extends RexShuttle { final OuterReferenceResolver referenceResolver; - RexVisitor(OuterReferenceResolver referenceResolver) { + RexVisitor(final OuterReferenceResolver referenceResolver) { this.referenceResolver = referenceResolver; } @Override - public RexNode visitSubQuery(RexSubQuery subQuery) { + public RexNode visitSubQuery(final RexSubQuery subQuery) { referenceResolver.nestedDepth.replaceAll((k, v) -> v + 1); referenceResolver.apply(subQuery.rel); // look inside sub-queries @@ -109,9 +109,9 @@ public RexNode visitSubQuery(RexSubQuery subQuery) { } @Override - public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + public RexNode visitFieldAccess(final RexFieldAccess fieldAccess) { if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) { - CorrelationId id = ((RexCorrelVariable) fieldAccess.getReferenceExpr()).id; + final CorrelationId id = ((RexCorrelVariable) fieldAccess.getReferenceExpr()).id; if (referenceResolver.nestedDepth.containsKey(id)) { referenceResolver.fieldAccessDepthMap.put( fieldAccess, referenceResolver.nestedDepth.get(id)); diff --git a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java index f2419ab01..d8c52f847 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java +++ b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java @@ -30,7 +30,7 @@ public class PreCalciteAggregateValidator { * @param aggregate * @return */ - public static boolean isValidCalciteAggregate(Aggregate aggregate) { + public static boolean isValidCalciteAggregate(final Aggregate aggregate) { return aggregate.getMeasures().stream() .allMatch(PreCalciteAggregateValidator::isValidCalciteMeasure) && aggregate.getGroupings().stream() @@ -45,7 +45,7 @@ public static boolean isValidCalciteAggregate(Aggregate aggregate) { * @return true if the {@code measure} can be converted to a Calcite equivalent without changes, * false otherwise. */ - private static boolean isValidCalciteMeasure(Aggregate.Measure measure) { + private static boolean isValidCalciteMeasure(final Aggregate.Measure measure) { return // all function arguments to measures must be field references measure.getFunction().arguments().stream().allMatch(farg -> isSimpleFieldReference(farg)) @@ -67,7 +67,7 @@ private static boolean isValidCalciteMeasure(Aggregate.Measure measure) { * @return true if the {@code grouping} can be converted to a Calcite equivalent without changes, * false otherwise. */ - private static boolean isValidCalciteGrouping(Aggregate.Grouping grouping) { + private static boolean isValidCalciteGrouping(final Aggregate.Grouping grouping) { if (!grouping.getExpressions().stream().allMatch(e -> isSimpleFieldReference(e))) { // all grouping expressions must be field references return false; @@ -81,7 +81,7 @@ private static boolean isValidCalciteGrouping(Aggregate.Grouping grouping) { // For example, if a grouping is defined as (0, 2, 1) in Substrait, Calcite will output it as // (0, 1, 2), which means that the Calcite output will no longer line up with the expectations // of the Substrait plan. - List groupingFields = + final List groupingFields = grouping.getExpressions().stream() // isSimpleFieldReference above guarantees that the expr is a FieldReference .map(expr -> getFieldRefOffset((FieldReference) expr)) @@ -90,20 +90,20 @@ private static boolean isValidCalciteGrouping(Aggregate.Grouping grouping) { return isOrdered(groupingFields); } - private static boolean isSimpleFieldReference(FunctionArg e) { + private static boolean isSimpleFieldReference(final FunctionArg e) { if (!(e instanceof FieldReference)) { return false; } - List segments = ((FieldReference) e).segments(); + final List segments = ((FieldReference) e).segments(); return segments.size() == 1 && segments.get(0) instanceof FieldReference.StructField; } - private static int getFieldRefOffset(FieldReference fr) { + private static int getFieldRefOffset(final FieldReference fr) { return ((FieldReference.StructField) fr.segments().get(0)).offset(); } - private static boolean isOrdered(List list) { + private static boolean isOrdered(final List list) { for (int i = 1; i < list.size(); i++) { if (list.get(i - 1) > list.get(i)) { return false; @@ -120,7 +120,7 @@ public static class PreCalciteAggregateTransformer { // Tracks the offset of the next expression added private int expressionOffset; - private PreCalciteAggregateTransformer(Aggregate aggregate) { + private PreCalciteAggregateTransformer(final Aggregate aggregate) { this.newExpressions = new ArrayList<>(); // The Substrait project output includes all input fields, followed by expressions this.expressionOffset = aggregate.getInput().getRecordType().fields().size(); @@ -135,15 +135,15 @@ private PreCalciteAggregateTransformer(Aggregate aggregate) { *
  • Adding all groupings to this project so that they are referenced in "order" * */ - public static Aggregate transformToValidCalciteAggregate(Aggregate aggregate) { - PreCalciteAggregateTransformer at = new PreCalciteAggregateTransformer(aggregate); + public static Aggregate transformToValidCalciteAggregate(final Aggregate aggregate) { + final PreCalciteAggregateTransformer at = new PreCalciteAggregateTransformer(aggregate); - List newMeasures = + final List newMeasures = aggregate.getMeasures().stream().map(at::updateMeasure).collect(Collectors.toList()); - List newGroupings = + final List newGroupings = aggregate.getGroupings().stream().map(at::updateGrouping).collect(Collectors.toList()); - Project preAggregateProject = + final Project preAggregateProject = Project.builder().input(aggregate.getInput()).expressions(at.newExpressions).build(); return Aggregate.builder() @@ -154,15 +154,15 @@ public static Aggregate transformToValidCalciteAggregate(Aggregate aggregate) { .build(); } - private Aggregate.Measure updateMeasure(Aggregate.Measure measure) { - AggregateFunctionInvocation oldAggregateFunctionInvocation = measure.getFunction(); + private Aggregate.Measure updateMeasure(final Aggregate.Measure measure) { + final AggregateFunctionInvocation oldAggregateFunctionInvocation = measure.getFunction(); - List newFunctionArgs = + final List newFunctionArgs = oldAggregateFunctionInvocation.arguments().stream() .map(this::projectOutNonFieldReference) .collect(Collectors.toList()); - List newSortFields = + final List newSortFields = oldAggregateFunctionInvocation.sort().stream() .map( sf -> @@ -172,10 +172,10 @@ private Aggregate.Measure updateMeasure(Aggregate.Measure measure) { .build()) .collect(Collectors.toList()); - Optional newPreMeasureFilter = + final Optional newPreMeasureFilter = measure.getPreMeasureFilter().map(this::projectOutNonFieldReference); - AggregateFunctionInvocation newAggregateFunctionInvocation = + final AggregateFunctionInvocation newAggregateFunctionInvocation = AggregateFunctionInvocation.builder() .from(oldAggregateFunctionInvocation) .arguments(newFunctionArgs) @@ -188,15 +188,15 @@ private Aggregate.Measure updateMeasure(Aggregate.Measure measure) { .build(); } - private Aggregate.Grouping updateGrouping(Aggregate.Grouping grouping) { + private Aggregate.Grouping updateGrouping(final Aggregate.Grouping grouping) { // project out all groupings unconditionally, even field references // this ensures that out of order groupings are re-projected into in order groupings - List newGroupingExpressions = + final List newGroupingExpressions = grouping.getExpressions().stream().map(this::projectOut).collect(Collectors.toList()); return Aggregate.Grouping.builder().expressions(newGroupingExpressions).build(); } - private Expression projectOutNonFieldReference(FunctionArg farg) { + private Expression projectOutNonFieldReference(final FunctionArg farg) { if ((farg instanceof Expression)) { return projectOutNonFieldReference((Expression) farg); } else { @@ -204,7 +204,7 @@ private Expression projectOutNonFieldReference(FunctionArg farg) { } } - private Expression projectOutNonFieldReference(Expression expr) { + private Expression projectOutNonFieldReference(final Expression expr) { if (isSimpleFieldReference(expr)) { return expr; } @@ -216,7 +216,7 @@ private Expression projectOutNonFieldReference(Expression expr) { * PreCalciteAggregateTransformer#expressionOffset} and returns a field reference to the new * expression */ - private Expression projectOut(Expression expr) { + private Expression projectOut(final Expression expr) { newExpressions.add(expr); return FieldReference.builder() // create a field reference to the new expression, then update the expression offset diff --git a/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java index 81c4e9a49..508f2f8ed 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java @@ -21,67 +21,67 @@ /** A more generic version of RelShuttle that allows an alternative return value. */ public abstract class RelNodeVisitor { - public OUTPUT visit(TableScan scan) throws EXCEPTION { + public OUTPUT visit(final TableScan scan) throws EXCEPTION { return visitOther(scan); } - public OUTPUT visit(TableFunctionScan scan) throws EXCEPTION { + public OUTPUT visit(final TableFunctionScan scan) throws EXCEPTION { return visitOther(scan); } - public OUTPUT visit(Values values) throws EXCEPTION { + public OUTPUT visit(final Values values) throws EXCEPTION { return visitOther(values); } - public OUTPUT visit(Filter filter) throws EXCEPTION { + public OUTPUT visit(final Filter filter) throws EXCEPTION { return visitOther(filter); } - public OUTPUT visit(Calc calc) throws EXCEPTION { + public OUTPUT visit(final Calc calc) throws EXCEPTION { return visitOther(calc); } - public OUTPUT visit(Project project) throws EXCEPTION { + public OUTPUT visit(final Project project) throws EXCEPTION { return visitOther(project); } - public OUTPUT visit(Join join) throws EXCEPTION { + public OUTPUT visit(final Join join) throws EXCEPTION { return visitOther(join); } - public OUTPUT visit(Correlate correlate) throws EXCEPTION { + public OUTPUT visit(final Correlate correlate) throws EXCEPTION { return visitOther(correlate); } - public OUTPUT visit(Union union) throws EXCEPTION { + public OUTPUT visit(final Union union) throws EXCEPTION { return visitOther(union); } - public OUTPUT visit(Intersect intersect) throws EXCEPTION { + public OUTPUT visit(final Intersect intersect) throws EXCEPTION { return visitOther(intersect); } - public OUTPUT visit(Minus minus) throws EXCEPTION { + public OUTPUT visit(final Minus minus) throws EXCEPTION { return visitOther(minus); } - public OUTPUT visit(Aggregate aggregate) throws EXCEPTION { + public OUTPUT visit(final Aggregate aggregate) throws EXCEPTION { return visitOther(aggregate); } - public OUTPUT visit(Match match) throws EXCEPTION { + public OUTPUT visit(final Match match) throws EXCEPTION { return visitOther(match); } - public OUTPUT visit(Sort sort) throws EXCEPTION { + public OUTPUT visit(final Sort sort) throws EXCEPTION { return visitOther(sort); } - public OUTPUT visit(Exchange exchange) throws EXCEPTION { + public OUTPUT visit(final Exchange exchange) throws EXCEPTION { return visitOther(exchange); } - public OUTPUT visit(TableModify modify) throws EXCEPTION { + public OUTPUT visit(final TableModify modify) throws EXCEPTION { return visitOther(modify); } @@ -93,7 +93,7 @@ protected RelNodeVisitor() {} * The method you call when you would normally call RelNode.accept(visitor). Instead call * RelVisitor.reverseAccept(RelNode) due to the lack of ability to extend base classes. */ - public final OUTPUT reverseAccept(RelNode node) throws EXCEPTION { + public final OUTPUT reverseAccept(final RelNode node) throws EXCEPTION { if (node instanceof TableScan) { return this.visit((TableScan) node); } else if (node instanceof TableFunctionScan) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java b/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java index 99eaac1ab..05e0691f1 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java @@ -25,7 +25,7 @@ public class SchemaCollector { private final RelDataTypeFactory typeFactory; private final TypeConverter typeConverter; - public SchemaCollector(RelDataTypeFactory typeFactory, TypeConverter typeConverter) { + public SchemaCollector(final RelDataTypeFactory typeFactory, final TypeConverter typeConverter) { this.typeFactory = typeFactory; this.typeConverter = typeConverter; } @@ -39,22 +39,23 @@ public SchemaCollector(RelDataTypeFactory typeFactory, TypeConverter typeConvert */ public CalciteSchema toSchema(@NonNull final Rel rel) { // Create the root schema under which all tables and schemas will be nested. - CalciteSchema rootSchema = CalciteSchema.createRootSchema(false, false); + final CalciteSchema rootSchema = CalciteSchema.createRootSchema(false, false); - for (Map.Entry, NamedStruct> entry : TableGatherer.gatherTables(rel).entrySet()) { - List names = entry.getKey(); - NamedStruct namedStruct = entry.getValue(); + for (final Map.Entry, NamedStruct> entry : + TableGatherer.gatherTables(rel).entrySet()) { + final List names = entry.getKey(); + final NamedStruct namedStruct = entry.getValue(); // The last name in names is the table name. All others are schema names. - String tableName = names.get(names.size() - 1); + final String tableName = names.get(names.size() - 1); - CalciteSchema schema = + final CalciteSchema schema = Utils.createCalciteSchemaFromNames(rootSchema, names.subList(0, names.size() - 1)); // Create the table if it is not present - CalciteSchema.TableEntry table = schema.getTable(tableName, CASE_SENSITIVE); + final CalciteSchema.TableEntry table = schema.getTable(tableName, CASE_SENSITIVE); if (table == null) { - RelDataType rowType = + final RelDataType rowType = typeConverter.toCalcite(typeFactory, namedStruct.struct(), namedStruct.names()); schema.add(tableName, new SubstraitTable(tableName, rowType)); } @@ -78,19 +79,19 @@ private TableGatherer() { * @param rootRel under which to search for {@link NamedScan}s * @return a map of qualified table names to their associated Substrait schemas */ - public static Map, NamedStruct> gatherTables(Rel rootRel) { - TableGatherer visitor = new TableGatherer(); + public static Map, NamedStruct> gatherTables(final Rel rootRel) { + final TableGatherer visitor = new TableGatherer(); rootRel.accept(visitor, EmptyVisitationContext.INSTANCE); return visitor.tableMap; } @Override - public Optional visit(NamedScan namedScan, EmptyVisitationContext context) { + public Optional visit(final NamedScan namedScan, final EmptyVisitationContext context) { super.visit(namedScan, context); - List tableName = namedScan.getNames(); + final List tableName = namedScan.getNames(); if (tableMap.containsKey(tableName)) { - NamedStruct existingSchema = tableMap.get(tableName); + final NamedStruct existingSchema = tableMap.get(tableName); if (!existingSchema.equals(namedScan.getInitialSchema())) { throw new IllegalArgumentException( String.format( @@ -103,12 +104,12 @@ public Optional visit(NamedScan namedScan, EmptyVisitationContext context) } @Override - public Optional visit(NamedWrite namedWrite, EmptyVisitationContext context) { + public Optional visit(final NamedWrite namedWrite, final EmptyVisitationContext context) { super.visit(namedWrite, context); - List tableName = namedWrite.getNames(); + final List tableName = namedWrite.getNames(); if (tableMap.containsKey(tableName)) { - NamedStruct existingSchema = tableMap.get(tableName); + final NamedStruct existingSchema = tableMap.get(tableName); if (!existingSchema.equals(namedWrite.getTableSchema())) { throw new IllegalArgumentException( String.format( @@ -121,12 +122,13 @@ public Optional visit(NamedWrite namedWrite, EmptyVisitationContext context } @Override - public Optional visit(NamedUpdate namedUpdate, EmptyVisitationContext context) { + public Optional visit( + final NamedUpdate namedUpdate, final EmptyVisitationContext context) { super.visit(namedUpdate, context); - List tableName = namedUpdate.getNames(); + final List tableName = namedUpdate.getNames(); if (tableMap.containsKey(tableName)) { - NamedStruct existingSchema = tableMap.get(tableName); + final NamedStruct existingSchema = tableMap.get(tableName); if (!existingSchema.equals(namedUpdate.getTableSchema())) { throw new IllegalArgumentException( String.format( diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index 671deabe5..47fc6c56f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -36,16 +36,16 @@ public class SqlConverterBase { protected static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build(); final FeatureBoard featureBoard; - protected SqlConverterBase(FeatureBoard features) { + protected SqlConverterBase(final FeatureBoard features) { this.factory = SubstraitTypeSystem.TYPE_FACTORY; this.config = CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false"); this.converterConfig = SqlToRelConverter.config().withTrimUnusedFields(true).withExpand(false); - VolcanoPlanner planner = new VolcanoPlanner(RelOptCostImpl.FACTORY, Contexts.of("hello")); + final VolcanoPlanner planner = new VolcanoPlanner(RelOptCostImpl.FACTORY, Contexts.of("hello")); this.relOptCluster = RelOptCluster.create(planner, new RexBuilder(factory)); relOptCluster.setMetadataQuerySupplier( () -> { - ProxyingMetadataHandlerProvider handler = + final ProxyingMetadataHandlerProvider handler = new ProxyingMetadataHandlerProvider(DefaultRelMetadataProvider.INSTANCE); return new RelMetadataQuery(handler); }); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index c32fab07c..ab4b1eb10 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -38,9 +38,9 @@ public SqlExpressionToSubstrait() { } public SqlExpressionToSubstrait( - FeatureBoard features, SimpleExtension.ExtensionCollection extensions) { + final FeatureBoard features, final SimpleExtension.ExtensionCollection extensions) { super(features); - ScalarFunctionConverter scalarFunctionConverter = + final ScalarFunctionConverter scalarFunctionConverter = new ScalarFunctionConverter(extensions.scalarFunctions(), factory); this.rexConverter = new RexExpressionConverter(scalarFunctionConverter); } @@ -52,10 +52,10 @@ private static final class Result { final Map nameToNodeMap; Result( - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) { + final SqlValidator validator, + final CalciteCatalogReader catalogReader, + final Map nameToTypeMap, + final Map nameToNodeMap) { this.validator = validator; this.catalogReader = catalogReader; this.nameToTypeMap = nameToTypeMap; @@ -72,7 +72,7 @@ private static final class Result { * @throws SqlParseException */ public io.substrait.proto.ExtendedExpression convert( - String sqlExpression, List createStatements) throws SqlParseException { + final String sqlExpression, final List createStatements) throws SqlParseException { return convert(new String[] {sqlExpression}, createStatements); } @@ -85,8 +85,8 @@ public io.substrait.proto.ExtendedExpression convert( * @throws SqlParseException */ public io.substrait.proto.ExtendedExpression convert( - String[] sqlExpressions, List createStatements) throws SqlParseException { - Result result = registerCreateTablesForExtendedExpression(createStatements); + final String[] sqlExpressions, final List createStatements) throws SqlParseException { + final Result result = registerCreateTablesForExtendedExpression(createStatements); return executeInnerSQLExpressions( sqlExpressions, result.validator, @@ -96,28 +96,28 @@ public io.substrait.proto.ExtendedExpression convert( } private io.substrait.proto.ExtendedExpression executeInnerSQLExpressions( - String[] sqlExpressions, - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) + final String[] sqlExpressions, + final SqlValidator validator, + final CalciteCatalogReader catalogReader, + final Map nameToTypeMap, + final Map nameToNodeMap) throws SqlParseException { int columnIndex = 1; - List expressionReferences = new ArrayList<>(); + final List expressionReferences = new ArrayList<>(); RexNode rexNode; - for (String sqlExpression : sqlExpressions) { + for (final String sqlExpression : sqlExpressions) { rexNode = sqlToRexNode( sqlExpression.trim(), validator, catalogReader, nameToTypeMap, nameToNodeMap); - ExtendedExpression.ExpressionReference expressionReference = + final ExtendedExpression.ExpressionReference expressionReference = ExtendedExpression.ExpressionReference.builder() .expression(rexNode.accept(this.rexConverter)) .addOutputNames("column-" + columnIndex++) .build(); expressionReferences.add(expressionReference); } - NamedStruct namedStruct = toNamedStruct(nameToTypeMap); - Builder extendedExpression = + final NamedStruct namedStruct = toNamedStruct(nameToTypeMap); + final Builder extendedExpression = ExtendedExpression.builder() .referredExpressions(expressionReferences) .baseSchema(namedStruct); @@ -126,16 +126,16 @@ private io.substrait.proto.ExtendedExpression executeInnerSQLExpressions( } private RexNode sqlToRexNode( - String sql, - SqlValidator validator, - CalciteCatalogReader catalogReader, - Map nameToTypeMap, - Map nameToNodeMap) + final String sql, + final SqlValidator validator, + final CalciteCatalogReader catalogReader, + final Map nameToTypeMap, + final Map nameToNodeMap) throws SqlParseException { - SqlParser parser = SqlParser.create(sql, parserConfig); - SqlNode sqlNode = parser.parseExpression(); - SqlNode validSqlNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap); - SqlToRelConverter converter = + final SqlParser parser = SqlParser.create(sql, parserConfig); + final SqlNode sqlNode = parser.parseExpression(); + final SqlNode validSqlNode = validator.validateParameterizedExpression(sqlNode, nameToTypeMap); + final SqlToRelConverter converter = new SqlToRelConverter( null, validator, @@ -146,20 +146,20 @@ private RexNode sqlToRexNode( return converter.convertExpression(validSqlNode, nameToNodeMap); } - private Result registerCreateTablesForExtendedExpression(List tables) + private Result registerCreateTablesForExtendedExpression(final List tables) throws SqlParseException { - Map nameToTypeMap = new LinkedHashMap<>(); - Map nameToNodeMap = new HashMap<>(); - CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); - CalciteCatalogReader catalogReader = + final Map nameToTypeMap = new LinkedHashMap<>(); + final Map nameToNodeMap = new HashMap<>(); + final CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); + final CalciteCatalogReader catalogReader = new CalciteCatalogReader(rootSchema, List.of(), factory, config); if (tables != null) { - for (String tableDef : tables) { - List tList = + for (final String tableDef : tables) { + final List tList = SubstraitCreateStatementParser.processCreateStatements(tableDef); - for (SubstraitTable t : tList) { + for (final SubstraitTable t : tList) { rootSchema.add(t.getName(), t); - for (RelDataTypeField field : t.getRowType(factory).getFieldList()) { + for (final RelDataTypeField field : t.getRowType(factory).getFieldList()) { nameToTypeMap.merge( // to validate the sql expression tree field.getName(), field.getType(), @@ -178,16 +178,16 @@ private Result registerCreateTablesForExtendedExpression(List tables) } } } - SqlValidator validator = new SubstraitSqlValidator(catalogReader); + final SqlValidator validator = new SubstraitSqlValidator(catalogReader); return new Result(validator, catalogReader, nameToTypeMap, nameToNodeMap); } - private NamedStruct toNamedStruct(Map nameToTypeMap) { - ArrayList names = new ArrayList(); - ArrayList types = new ArrayList(); - for (Map.Entry entry : nameToTypeMap.entrySet()) { - String k = entry.getKey(); - RelDataType v = entry.getValue(); + private NamedStruct toNamedStruct(final Map nameToTypeMap) { + final ArrayList names = new ArrayList(); + final ArrayList types = new ArrayList(); + for (final Map.Entry entry : nameToTypeMap.entrySet()) { + final String k = entry.getKey(); + final RelDataType v = entry.getValue(); names.add(k); types.add(TypeConverter.DEFAULT.toSubstrait(v)); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlKindFromRel.java b/isthmus/src/main/java/io/substrait/isthmus/SqlKindFromRel.java index ef2da19f0..f9e0dfac8 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlKindFromRel.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlKindFromRel.java @@ -46,36 +46,40 @@ public class SqlKindFromRel private static final SqlKind QUERY_KIND = SqlKind.SELECT; @Override - public SqlKind visit(Aggregate aggregate, EmptyVisitationContext context) + public SqlKind visit(final Aggregate aggregate, final EmptyVisitationContext context) throws RuntimeException { return QUERY_KIND; } @Override - public SqlKind visit(EmptyScan emptyScan, EmptyVisitationContext context) + public SqlKind visit(final EmptyScan emptyScan, final EmptyVisitationContext context) throws RuntimeException { // An empty scan is typically the result of a query that returns no rows. return QUERY_KIND; } @Override - public SqlKind visit(Fetch fetch, EmptyVisitationContext context) throws RuntimeException { + public SqlKind visit(final Fetch fetch, final EmptyVisitationContext context) + throws RuntimeException { return QUERY_KIND; } @Override - public SqlKind visit(Filter filter, EmptyVisitationContext context) throws RuntimeException { + public SqlKind visit(final Filter filter, final EmptyVisitationContext context) + throws RuntimeException { return QUERY_KIND; } @Override - public SqlKind visit(Join join, EmptyVisitationContext context) throws RuntimeException { + public SqlKind visit(final Join join, final EmptyVisitationContext context) + throws RuntimeException { return SqlKind.JOIN; } @Override - public SqlKind visit(Set set, EmptyVisitationContext context) throws RuntimeException { + public SqlKind visit(final Set set, final EmptyVisitationContext context) + throws RuntimeException { switch (set.getSetOp()) { case UNION_ALL: case UNION_DISTINCT: @@ -95,94 +99,102 @@ public SqlKind visit(Set set, EmptyVisitationContext context) throws RuntimeExce } @Override - public SqlKind visit(NamedScan namedScan, EmptyVisitationContext context) + public SqlKind visit(final NamedScan namedScan, final EmptyVisitationContext context) throws RuntimeException { return QUERY_KIND; } @Override - public SqlKind visit(LocalFiles localFiles, EmptyVisitationContext context) + public SqlKind visit(final LocalFiles localFiles, final EmptyVisitationContext context) throws RuntimeException { return QUERY_KIND; } @Override - public SqlKind visit(Project project, EmptyVisitationContext context) throws RuntimeException { + public SqlKind visit(final Project project, final EmptyVisitationContext context) + throws RuntimeException { return QUERY_KIND; } @Override - public SqlKind visit(Expand expand, EmptyVisitationContext context) throws RuntimeException { + public SqlKind visit(final Expand expand, final EmptyVisitationContext context) + throws RuntimeException { return QUERY_KIND; } @Override - public SqlKind visit(Sort sort, EmptyVisitationContext context) throws RuntimeException { + public SqlKind visit(final Sort sort, final EmptyVisitationContext context) + throws RuntimeException { return SqlKind.ORDER_BY; } @Override - public SqlKind visit(Cross cross, EmptyVisitationContext context) throws RuntimeException { + public SqlKind visit(final Cross cross, final EmptyVisitationContext context) + throws RuntimeException { return SqlKind.JOIN; } @Override - public SqlKind visit(VirtualTableScan virtualTableScan, EmptyVisitationContext context) + public SqlKind visit( + final VirtualTableScan virtualTableScan, final EmptyVisitationContext context) throws RuntimeException { // A virtual table scan corresponds to a VALUES clause. return SqlKind.VALUES; } @Override - public SqlKind visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) + public SqlKind visit(final ExtensionLeaf extensionLeaf, final EmptyVisitationContext context) throws RuntimeException { return SqlKind.OTHER; } @Override - public SqlKind visit(ExtensionSingle extensionSingle, EmptyVisitationContext context) + public SqlKind visit(final ExtensionSingle extensionSingle, final EmptyVisitationContext context) throws RuntimeException { return SqlKind.OTHER; } @Override - public SqlKind visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) + public SqlKind visit(final ExtensionMulti extensionMulti, final EmptyVisitationContext context) throws RuntimeException { return SqlKind.OTHER; } @Override - public SqlKind visit(ExtensionTable extensionTable, EmptyVisitationContext context) + public SqlKind visit(final ExtensionTable extensionTable, final EmptyVisitationContext context) throws RuntimeException { return SqlKind.OTHER; } @Override - public SqlKind visit(HashJoin hashJoin, EmptyVisitationContext context) throws RuntimeException { + public SqlKind visit(final HashJoin hashJoin, final EmptyVisitationContext context) + throws RuntimeException { return SqlKind.JOIN; } @Override - public SqlKind visit(MergeJoin mergeJoin, EmptyVisitationContext context) + public SqlKind visit(final MergeJoin mergeJoin, final EmptyVisitationContext context) throws RuntimeException { return SqlKind.JOIN; } @Override - public SqlKind visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) + public SqlKind visit(final NestedLoopJoin nestedLoopJoin, final EmptyVisitationContext context) throws RuntimeException { return SqlKind.JOIN; } @Override public SqlKind visit( - ConsistentPartitionWindow consistentPartitionWindow, EmptyVisitationContext context) + final ConsistentPartitionWindow consistentPartitionWindow, + final EmptyVisitationContext context) throws RuntimeException { return SqlKind.OVER; } @Override - public SqlKind visit(NamedWrite write, EmptyVisitationContext context) throws RuntimeException { + public SqlKind visit(final NamedWrite write, final EmptyVisitationContext context) + throws RuntimeException { switch (write.getOperation()) { case INSERT: return SqlKind.INSERT; @@ -198,13 +210,14 @@ public SqlKind visit(NamedWrite write, EmptyVisitationContext context) throws Ru } @Override - public SqlKind visit(ExtensionWrite write, EmptyVisitationContext context) + public SqlKind visit(final ExtensionWrite write, final EmptyVisitationContext context) throws RuntimeException { return SqlKind.OTHER_DDL; } @Override - public SqlKind visit(NamedDdl ddl, EmptyVisitationContext context) throws RuntimeException { + public SqlKind visit(final NamedDdl ddl, final EmptyVisitationContext context) + throws RuntimeException { switch (ddl.getOperation()) { case CREATE: case CREATE_OR_REPLACE: @@ -234,41 +247,43 @@ public SqlKind visit(NamedDdl ddl, EmptyVisitationContext context) throws Runtim } @Override - public SqlKind visit(ExtensionDdl ddl, EmptyVisitationContext context) throws RuntimeException { + public SqlKind visit(final ExtensionDdl ddl, final EmptyVisitationContext context) + throws RuntimeException { return SqlKind.OTHER_DDL; } @Override - public SqlKind visit(NamedUpdate update, EmptyVisitationContext context) throws RuntimeException { + public SqlKind visit(final NamedUpdate update, final EmptyVisitationContext context) + throws RuntimeException { return SqlKind.UPDATE; } @Override - public SqlKind visit(ScatterExchange exchange, EmptyVisitationContext context) + public SqlKind visit(final ScatterExchange exchange, final EmptyVisitationContext context) throws RuntimeException { return SqlKind.OTHER_DDL; } @Override - public SqlKind visit(SingleBucketExchange exchange, EmptyVisitationContext context) + public SqlKind visit(final SingleBucketExchange exchange, final EmptyVisitationContext context) throws RuntimeException { return SqlKind.OTHER_DDL; } @Override - public SqlKind visit(MultiBucketExchange exchange, EmptyVisitationContext context) + public SqlKind visit(final MultiBucketExchange exchange, final EmptyVisitationContext context) throws RuntimeException { return SqlKind.OTHER_DDL; } @Override - public SqlKind visit(RoundRobinExchange exchange, EmptyVisitationContext context) + public SqlKind visit(final RoundRobinExchange exchange, final EmptyVisitationContext context) throws RuntimeException { return SqlKind.OTHER_DDL; } @Override - public SqlKind visit(BroadcastExchange exchange, EmptyVisitationContext context) + public SqlKind visit(final BroadcastExchange exchange, final EmptyVisitationContext context) throws RuntimeException { return SqlKind.OTHER_DDL; } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 3e19ca58c..59893dda5 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -15,7 +15,7 @@ public SqlToSubstrait() { this(null); } - public SqlToSubstrait(FeatureBoard features) { + public SqlToSubstrait(final FeatureBoard features) { super(features); } @@ -32,9 +32,10 @@ public SqlToSubstrait(FeatureBoard features) { * {@link PlanProtoConverter#toProto(Plan)} */ @Deprecated - public io.substrait.proto.Plan execute(String sqlStatements, Prepare.CatalogReader catalogReader) + public io.substrait.proto.Plan execute( + final String sqlStatements, final Prepare.CatalogReader catalogReader) throws SqlParseException { - PlanProtoConverter planToProto = new PlanProtoConverter(); + final PlanProtoConverter planToProto = new PlanProtoConverter(); return planToProto.toProto(convert(sqlStatements, catalogReader)); } @@ -47,9 +48,9 @@ public io.substrait.proto.Plan execute(String sqlStatements, Prepare.CatalogRead * @return the Substrait {@link Plan} * @throws SqlParseException if there is an error while parsing the SQL statements */ - public Plan convert(String sqlStatements, Prepare.CatalogReader catalogReader) + public Plan convert(final String sqlStatements, final Prepare.CatalogReader catalogReader) throws SqlParseException { - Builder builder = io.substrait.plan.Plan.builder(); + final Builder builder = io.substrait.plan.Plan.builder(); builder.version(Version.builder().from(Version.DEFAULT_VERSION).producer("isthmus").build()); // TODO: consider case in which one sql passes conversion while others don't diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 25f128a2d..052701653 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -105,9 +105,9 @@ public class SubstraitRelNodeConverter private final TypeConverter typeConverter; public SubstraitRelNodeConverter( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - RelBuilder relBuilder) { + final SimpleExtension.ExtensionCollection extensions, + final RelDataTypeFactory typeFactory, + final RelBuilder relBuilder) { this( typeFactory, relBuilder, @@ -118,12 +118,12 @@ public SubstraitRelNodeConverter( } public SubstraitRelNodeConverter( - RelDataTypeFactory typeFactory, - RelBuilder relBuilder, - ScalarFunctionConverter scalarFunctionConverter, - AggregateFunctionConverter aggregateFunctionConverter, - WindowFunctionConverter windowFunctionConverter, - TypeConverter typeConverter) { + final RelDataTypeFactory typeFactory, + final RelBuilder relBuilder, + final ScalarFunctionConverter scalarFunctionConverter, + final AggregateFunctionConverter aggregateFunctionConverter, + final WindowFunctionConverter windowFunctionConverter, + final TypeConverter typeConverter) { this( typeFactory, relBuilder, @@ -136,13 +136,13 @@ public SubstraitRelNodeConverter( } public SubstraitRelNodeConverter( - RelDataTypeFactory typeFactory, - RelBuilder relBuilder, - ScalarFunctionConverter scalarFunctionConverter, - AggregateFunctionConverter aggregateFunctionConverter, - WindowFunctionConverter windowFunctionConverter, - TypeConverter typeConverter, - ExpressionRexConverter expressionRexConverter) { + final RelDataTypeFactory typeFactory, + final RelBuilder relBuilder, + final ScalarFunctionConverter scalarFunctionConverter, + final AggregateFunctionConverter aggregateFunctionConverter, + final WindowFunctionConverter windowFunctionConverter, + final TypeConverter typeConverter, + final ExpressionRexConverter expressionRexConverter) { this.typeFactory = typeFactory; this.typeConverter = typeConverter; this.relBuilder = relBuilder; @@ -154,11 +154,11 @@ public SubstraitRelNodeConverter( } public static RelNode convert( - Rel relRoot, - RelOptCluster relOptCluster, - Prepare.CatalogReader catalogReader, - SqlParser.Config parserConfig) { - RelBuilder relBuilder = + final Rel relRoot, + final RelOptCluster relOptCluster, + final Prepare.CatalogReader catalogReader, + final SqlParser.Config parserConfig) { + final RelBuilder relBuilder = RelBuilder.create( Frameworks.newConfigBuilder() .parserConfig(parserConfig) @@ -174,51 +174,51 @@ public static RelNode convert( } @Override - public RelNode visit(Filter filter, Context context) throws RuntimeException { - RelNode input = filter.getInput().accept(this, context); + public RelNode visit(final Filter filter, final Context context) throws RuntimeException { + final RelNode input = filter.getInput().accept(this, context); context.pushOuterRowType(input.getRowType()); - RexNode filterCondition = filter.getCondition().accept(expressionRexConverter, context); - RelNode node = + final RexNode filterCondition = filter.getCondition().accept(expressionRexConverter, context); + final RelNode node = relBuilder.push(input).filter(context.popCorrelationIds(), filterCondition).build(); context.popOuterRowType(); return applyRemap(node, filter.getRemap()); } @Override - public RelNode visit(NamedScan namedScan, Context context) throws RuntimeException { - RelNode node = relBuilder.scan(namedScan.getNames()).build(); + public RelNode visit(final NamedScan namedScan, final Context context) throws RuntimeException { + final RelNode node = relBuilder.scan(namedScan.getNames()).build(); return applyRemap(node, namedScan.getRemap()); } @Override - public RelNode visit(LocalFiles localFiles, Context context) throws RuntimeException { + public RelNode visit(final LocalFiles localFiles, final Context context) throws RuntimeException { return visitFallback(localFiles, context); } @Override - public RelNode visit(EmptyScan emptyScan, Context context) throws RuntimeException { - RelDataType rowType = + public RelNode visit(final EmptyScan emptyScan, final Context context) throws RuntimeException { + final RelDataType rowType = typeConverter.toCalcite(relBuilder.getTypeFactory(), emptyScan.getInitialSchema().struct()); - RelNode node = LogicalValues.create(relBuilder.getCluster(), rowType, ImmutableList.of()); + final RelNode node = LogicalValues.create(relBuilder.getCluster(), rowType, ImmutableList.of()); return applyRemap(node, emptyScan.getRemap()); } @Override - public RelNode visit(Project project, Context context) throws RuntimeException { - RelNode child = project.getInput().accept(this, context); + public RelNode visit(final Project project, final Context context) throws RuntimeException { + final RelNode child = project.getInput().accept(this, context); context.pushOuterRowType(child.getRowType()); - Stream directOutputs = + final Stream directOutputs = IntStream.range(0, child.getRowType().getFieldCount()) .mapToObj(fieldIndex -> rexBuilder.makeInputRef(child, fieldIndex)); - Stream exprs = + final Stream exprs = project.getExpressions().stream().map(expr -> expr.accept(expressionRexConverter, context)); - List rexExprs = + final List rexExprs = Stream.concat(directOutputs, exprs).collect(java.util.stream.Collectors.toList()); - RelNode node = + final RelNode node = relBuilder .push(child) .project(rexExprs, List.of(), false, context.popCorrelationIds()) @@ -228,26 +228,26 @@ public RelNode visit(Project project, Context context) throws RuntimeException { } @Override - public RelNode visit(Cross cross, Context context) throws RuntimeException { - RelNode left = cross.getLeft().accept(this, context); - RelNode right = cross.getRight().accept(this, context); + public RelNode visit(final Cross cross, final Context context) throws RuntimeException { + final RelNode left = cross.getLeft().accept(this, context); + final RelNode right = cross.getRight().accept(this, context); // Calcite represents CROSS JOIN as the equivalent INNER JOIN with true condition - RelNode node = + final RelNode node = relBuilder.push(left).push(right).join(JoinRelType.INNER, relBuilder.literal(true)).build(); return applyRemap(node, cross.getRemap()); } @Override - public RelNode visit(Join join, Context context) throws RuntimeException { - RelNode left = join.getLeft().accept(this, context); - RelNode right = join.getRight().accept(this, context); + public RelNode visit(final Join join, final Context context) throws RuntimeException { + final RelNode left = join.getLeft().accept(this, context); + final RelNode right = join.getRight().accept(this, context); context.pushOuterRowType(left.getRowType(), right.getRowType()); - RexNode condition = + final RexNode condition = join.getCondition() .map(c -> c.accept(expressionRexConverter, context)) .orElse(relBuilder.literal(true)); - JoinRelType joinType = asJoinRelType(join); - RelNode node = + final JoinRelType joinType = asJoinRelType(join); + final RelNode node = relBuilder .push(left) .push(right) @@ -257,8 +257,8 @@ public RelNode visit(Join join, Context context) throws RuntimeException { return applyRemap(node, join.getRemap()); } - private JoinRelType asJoinRelType(Join join) { - Join.JoinType type = join.getJoinType(); + private JoinRelType asJoinRelType(final Join join) { + final Join.JoinType type = join.getJoinType(); if (type == JoinType.INNER) { return JoinRelType.INNER; @@ -292,7 +292,7 @@ private JoinRelType asJoinRelType(Join join) { } @Override - public RelNode visit(Set set, Context context) throws RuntimeException { + public RelNode visit(final Set set, final Context context) throws RuntimeException { set.getInputs() .forEach( input -> { @@ -302,13 +302,13 @@ public RelNode visit(Set set, Context context) throws RuntimeException { // correspond to the Calcite relations they are associated with. They are retained for now // to enable users to migrate off of them. // See: https://github.com/substrait-io/substrait-java/issues/303 - RelBuilder builder = getRelBuilder(set); - RelNode node = builder.build(); + final RelBuilder builder = getRelBuilder(set); + final RelNode node = builder.build(); return applyRemap(node, set.getRemap()); } - private RelBuilder getRelBuilder(Set set) { - int numInputs = set.getInputs().size(); + private RelBuilder getRelBuilder(final Set set) { + final int numInputs = set.getInputs().size(); switch (set.getSetOp()) { case MINUS_PRIMARY: @@ -333,15 +333,15 @@ private RelBuilder getRelBuilder(Set set) { } @Override - public RelNode visit(Aggregate aggregate, Context context) throws RuntimeException { + public RelNode visit(Aggregate aggregate, final Context context) throws RuntimeException { if (!PreCalciteAggregateValidator.isValidCalciteAggregate(aggregate)) { aggregate = PreCalciteAggregateValidator.PreCalciteAggregateTransformer .transformToValidCalciteAggregate(aggregate); } - RelNode child = aggregate.getInput().accept(this, context); - List> groupExprLists = + final RelNode child = aggregate.getInput().accept(this, context); + final List> groupExprLists = aggregate.getGroupings().stream() .map( gr -> @@ -349,11 +349,11 @@ public RelNode visit(Aggregate aggregate, Context context) throws RuntimeExcepti .map(expr -> expr.accept(expressionRexConverter, context)) .collect(java.util.stream.Collectors.toList())) .collect(java.util.stream.Collectors.toList()); - List groupExprs = + final List groupExprs = groupExprLists.stream().flatMap(Collection::stream).collect(Collectors.toList()); - RelBuilder.GroupKey groupKey = relBuilder.groupKey(groupExprs, groupExprLists); + final RelBuilder.GroupKey groupKey = relBuilder.groupKey(groupExprs, groupExprLists); - List aggregateCalls = + final List aggregateCalls = aggregate.getMeasures().stream() .map(measure -> fromMeasure(measure, context)) .collect(java.util.stream.Collectors.toList()); @@ -381,7 +381,7 @@ public RelNode visit(Aggregate aggregate, Context context) throws RuntimeExcepti null)); final int groupingCallIndex = aggregateCalls.size() - 1; if (groupingSetIndexGetsRemapped) { - List remapList = new LinkedList<>(remap.get().indices()); + final List remapList = new LinkedList<>(remap.get().indices()); for (int i = 0; i < remapList.size(); i++) { if (remapList.get(i).equals(lastFieldIndex)) { // replace last field index with field index of the GROUP_ID() function call @@ -392,13 +392,13 @@ public RelNode visit(Aggregate aggregate, Context context) throws RuntimeExcepti } } - RelNode node = relBuilder.push(child).aggregate(groupKey, aggregateCalls).build(); + final RelNode node = relBuilder.push(child).aggregate(groupKey, aggregateCalls).build(); return applyRemap(node, remap); } - private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) { - List eArgs = measure.getFunction().arguments(); - List arguments = + private AggregateCall fromMeasure(final Aggregate.Measure measure, final Context context) { + final List eArgs = measure.getFunction().arguments(); + final List arguments = IntStream.range(0, measure.getFunction().arguments().size()) .mapToObj( i -> @@ -410,7 +410,7 @@ private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) { expressionRexConverter, context)) .collect(java.util.stream.Collectors.toList()); - Optional operator = + final Optional operator = aggregateFunctionConverter.getSqlOperatorFromSubstraitFunc( measure.getFunction().declaration().key(), measure.getFunction().outputType()); if (!operator.isPresent()) { @@ -418,23 +418,24 @@ private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) { String.format( "Unable to find binding for call %s", measure.getFunction().declaration().name())); } - List argIndex = new ArrayList<>(); - for (RexNode arg : arguments) { + final List argIndex = new ArrayList<>(); + for (final RexNode arg : arguments) { // arguments are guaranteed to be RexInputRef because of the prior call to // transformToValidCalciteAggregate argIndex.add(((RexInputRef) arg).getIndex()); } - boolean distinct = + final boolean distinct = measure.getFunction().invocation().equals(Expression.AggregationInvocation.DISTINCT); - SqlAggFunction aggFunction; - RelDataType returnType = typeConverter.toCalcite(typeFactory, measure.getFunction().getType()); + final SqlAggFunction aggFunction; + final RelDataType returnType = + typeConverter.toCalcite(typeFactory, measure.getFunction().getType()); if (operator.get() instanceof SqlAggFunction) { aggFunction = (SqlAggFunction) operator.get(); } else { - String msg = + final String msg = String.format( "Unable to convert non-aggregate operator: %s for substrait aggregate function %s", operator.get(), measure.getFunction().declaration().name()); @@ -443,7 +444,8 @@ private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) { int filterArg = -1; if (measure.getPreMeasureFilter().isPresent()) { - RexNode filter = measure.getPreMeasureFilter().get().accept(expressionRexConverter, context); + final RexNode filter = + measure.getPreMeasureFilter().get().accept(expressionRexConverter, context); filterArg = ((RexInputRef) filter).getIndex(); } @@ -471,20 +473,20 @@ private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) { } @Override - public RelNode visit(Sort sort, Context context) throws RuntimeException { - RelNode child = sort.getInput().accept(this, context); - List sortExpressions = + public RelNode visit(final Sort sort, final Context context) throws RuntimeException { + final RelNode child = sort.getInput().accept(this, context); + final List sortExpressions = sort.getSortFields().stream() .map(sortField -> directedRexNode(sortField, context)) .collect(Collectors.toList()); - RelNode node = relBuilder.push(child).sort(sortExpressions).build(); + final RelNode node = relBuilder.push(child).sort(sortExpressions).build(); return applyRemap(node, sort.getRemap()); } - private RexNode directedRexNode(Expression.SortField sortField, Context context) { - Expression expression = sortField.expr(); - RexNode rexNode = expression.accept(expressionRexConverter, context); - SortDirection sortDirection = sortField.direction(); + private RexNode directedRexNode(final Expression.SortField sortField, final Context context) { + final Expression expression = sortField.expr(); + final RexNode rexNode = expression.accept(expressionRexConverter, context); + final SortDirection sortDirection = sortField.direction(); if (sortDirection == Expression.SortDirection.ASC_NULLS_FIRST) { return relBuilder.nullsFirst(rexNode); @@ -507,11 +509,11 @@ private RexNode directedRexNode(Expression.SortField sortField, Context context) } @Override - public RelNode visit(Fetch fetch, Context context) throws RuntimeException { - RelNode child = fetch.getInput().accept(this, context); - OptionalLong optCount = fetch.getCount(); - long count = optCount.orElse(-1L); - long offset = fetch.getOffset(); + public RelNode visit(final Fetch fetch, final Context context) throws RuntimeException { + final RelNode child = fetch.getInput().accept(this, context); + final OptionalLong optCount = fetch.getCount(); + final long count = optCount.orElse(-1L); + final long offset = fetch.getOffset(); if (offset > Integer.MAX_VALUE) { throw new IllegalArgumentException( String.format("offset is overflowed as an integer: %d", offset)); @@ -520,16 +522,17 @@ public RelNode visit(Fetch fetch, Context context) throws RuntimeException { throw new IllegalArgumentException( String.format("count is overflowed as an integer: %d", count)); } - RelNode node = relBuilder.push(child).limit((int) offset, (int) count).build(); + final RelNode node = relBuilder.push(child).limit((int) offset, (int) count).build(); return applyRemap(node, fetch.getRemap()); } - private RelFieldCollation toRelFieldCollation(Expression.SortField sortField, Context context) { - Expression expression = sortField.expr(); - RexNode rex = expression.accept(expressionRexConverter, context); - SortDirection sortDirection = sortField.direction(); - RexSlot rexSlot = (RexSlot) rex; - int fieldIndex = rexSlot.getIndex(); + private RelFieldCollation toRelFieldCollation( + final Expression.SortField sortField, final Context context) { + final Expression expression = sortField.expr(); + final RexNode rex = expression.accept(expressionRexConverter, context); + final SortDirection sortDirection = sortField.direction(); + final RexSlot rexSlot = (RexSlot) rex; + final int fieldIndex = rexSlot.getIndex(); final RelFieldCollation.Direction fieldDirection; final RelFieldCollation.NullDirection nullDirection; @@ -558,19 +561,19 @@ private RelFieldCollation toRelFieldCollation(Expression.SortField sortField, Co } @Override - public RelNode visit(NamedUpdate update, Context context) { + public RelNode visit(final NamedUpdate update, final Context context) { relBuilder.scan(update.getNames()); - RexNode condition = update.getCondition().accept(expressionRexConverter, context); + final RexNode condition = update.getCondition().accept(expressionRexConverter, context); relBuilder.filter(condition); - RelNode inputForModify = relBuilder.build(); + final RelNode inputForModify = relBuilder.build(); - NamedStruct tableSchema = update.getTableSchema(); - List fieldNames = tableSchema.names(); + final NamedStruct tableSchema = update.getTableSchema(); + final List fieldNames = tableSchema.names(); - List updateColumnList = new ArrayList<>(); - List sourceExpressionList = new ArrayList<>(); + final List updateColumnList = new ArrayList<>(); + final List sourceExpressionList = new ArrayList<>(); - for (AbstractUpdate.TransformExpression transform : update.getTransformations()) { + for (final AbstractUpdate.TransformExpression transform : update.getTransformations()) { updateColumnList.add(fieldNames.get(transform.getColumnTarget())); sourceExpressionList.add( @@ -597,32 +600,37 @@ public RelNode visit(NamedUpdate update, Context context) { } @Override - public RelNode visit(ScatterExchange exchange, Context context) throws RuntimeException { + public RelNode visit(final ScatterExchange exchange, final Context context) + throws RuntimeException { return visitFallback(exchange, context); } @Override - public RelNode visit(SingleBucketExchange exchange, Context context) throws RuntimeException { + public RelNode visit(final SingleBucketExchange exchange, final Context context) + throws RuntimeException { return visitFallback(exchange, context); } @Override - public RelNode visit(MultiBucketExchange exchange, Context context) throws RuntimeException { + public RelNode visit(final MultiBucketExchange exchange, final Context context) + throws RuntimeException { return visitFallback(exchange, context); } @Override - public RelNode visit(RoundRobinExchange exchange, Context context) throws RuntimeException { + public RelNode visit(final RoundRobinExchange exchange, final Context context) + throws RuntimeException { return visitFallback(exchange, context); } @Override - public RelNode visit(BroadcastExchange exchange, Context context) throws RuntimeException { + public RelNode visit(final BroadcastExchange exchange, final Context context) + throws RuntimeException { return visitFallback(exchange, context); } @Override - public RelNode visit(NamedDdl namedDdl, Context context) { + public RelNode visit(final NamedDdl namedDdl, final Context context) { if (namedDdl.getOperation() != AbstractDdlRel.DdlOp.CREATE || namedDdl.getObject() != AbstractDdlRel.DdlObject.VIEW) { throw new UnsupportedOperationException( @@ -638,13 +646,13 @@ public RelNode visit(NamedDdl namedDdl, Context context) { throw new IllegalArgumentException("NamedDdl view definition must be set"); } - Rel viewDefinition = namedDdl.getViewDefinition().get(); - RelNode relNode = viewDefinition.accept(this, context); + final Rel viewDefinition = namedDdl.getViewDefinition().get(); + final RelNode relNode = viewDefinition.accept(this, context); return new CreateView(namedDdl.getNames(), relNode); } @Override - public RelNode visit(VirtualTableScan virtualTableScan, Context context) { + public RelNode visit(final VirtualTableScan virtualTableScan, final Context context) { final RelDataType typeInfoOnly = typeConverter.toCalcite(typeFactory, virtualTableScan.getInitialSchema().struct()); @@ -679,7 +687,7 @@ public RelNode visit(VirtualTableScan virtualTableScan, Context context) { relBuilder.getCluster(), rowTypeWithNames, ImmutableList.copyOf(tuples)); } - private RelNode handleCreateTableAs(NamedWrite namedWrite, Context context) { + private RelNode handleCreateTableAs(final NamedWrite namedWrite, final Context context) { if (namedWrite.getCreateMode() != AbstractWriteRel.CreateMode.REPLACE_IF_EXISTS || namedWrite.getOutputMode() != AbstractWriteRel.OutputMode.NO_OUTPUT) { throw new UnsupportedOperationException( @@ -691,19 +699,19 @@ private RelNode handleCreateTableAs(NamedWrite namedWrite, Context context) { namedWrite.getOutputMode())); } - Rel input = namedWrite.getInput(); - RelNode relNode = input.accept(this, context); + final Rel input = namedWrite.getInput(); + final RelNode relNode = input.accept(this, context); return new CreateTable(namedWrite.getNames(), relNode); } @Override - public RelNode visit(NamedWrite write, Context context) { - RelNode input = write.getInput().accept(this, context); + public RelNode visit(final NamedWrite write, final Context context) { + final RelNode input = write.getInput().accept(this, context); assert relBuilder.getRelOptSchema() != null; final RelOptTable targetTable = relBuilder.getRelOptSchema().getTableForMember(write.getNames()); - TableModify.Operation operation; + final TableModify.Operation operation; switch (write.getOperation()) { case INSERT: operation = TableModify.Operation.INSERT; @@ -734,28 +742,28 @@ public RelNode visit(NamedWrite write, Context context) { } @Override - public RelNode visitFallback(Rel rel, Context context) throws RuntimeException { + public RelNode visitFallback(final Rel rel, final Context context) throws RuntimeException { throw new UnsupportedOperationException( String.format( "Rel %s of type %s not handled by visitor type %s.", rel, rel.getClass().getCanonicalName(), this.getClass().getCanonicalName())); } - protected RelNode applyRemap(RelNode relNode, Optional remap) { + protected RelNode applyRemap(final RelNode relNode, final Optional remap) { if (remap.isPresent()) { return applyRemap(relNode, remap.get()); } return relNode; } - private RelNode applyRemap(RelNode relNode, Rel.Remap remap) { - RelDataType rowType = relNode.getRowType(); - List fieldNames = rowType.getFieldNames(); - List rexList = + private RelNode applyRemap(final RelNode relNode, final Rel.Remap remap) { + final RelDataType rowType = relNode.getRowType(); + final List fieldNames = rowType.getFieldNames(); + final List rexList = remap.indices().stream() .map( index -> { - RelDataTypeField t = rowType.getField(fieldNames.get(index), true, false); + final RelDataTypeField t = rowType.getField(fieldNames.get(index), true, false); return new RexInputRef(index, t.getValue()); }) .collect(java.util.stream.Collectors.toList()); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 181986289..84cc51edc 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -81,22 +81,22 @@ public class SubstraitRelVisitor extends RelNodeVisitor { private Map fieldAccessDepthMap; public SubstraitRelVisitor( - RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) { + final RelDataTypeFactory typeFactory, final SimpleExtension.ExtensionCollection extensions) { this(typeFactory, extensions, FEATURES_DEFAULT); } public SubstraitRelVisitor( - RelDataTypeFactory typeFactory, - SimpleExtension.ExtensionCollection extensions, - FeatureBoard features) { + final RelDataTypeFactory typeFactory, + final SimpleExtension.ExtensionCollection extensions, + final FeatureBoard features) { this.typeConverter = TypeConverter.DEFAULT; - ArrayList converters = new ArrayList(); + final ArrayList converters = new ArrayList(); converters.addAll(CallConverters.defaults(typeConverter)); converters.add(new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory)); converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory))); this.aggregateFunctionConverter = new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory); - WindowFunctionConverter windowFunctionConverter = + final WindowFunctionConverter windowFunctionConverter = new WindowFunctionConverter(extensions.windowFunctions(), typeFactory); this.rexExpressionConverter = new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter); @@ -104,13 +104,13 @@ public SubstraitRelVisitor( } public SubstraitRelVisitor( - RelDataTypeFactory typeFactory, - ScalarFunctionConverter scalarFunctionConverter, - AggregateFunctionConverter aggregateFunctionConverter, - WindowFunctionConverter windowFunctionConverter, - TypeConverter typeConverter, - FeatureBoard features) { - ArrayList converters = new ArrayList(); + final RelDataTypeFactory typeFactory, + final ScalarFunctionConverter scalarFunctionConverter, + final AggregateFunctionConverter aggregateFunctionConverter, + final WindowFunctionConverter windowFunctionConverter, + final TypeConverter typeConverter, + final FeatureBoard features) { + final ArrayList converters = new ArrayList(); converters.addAll(CallConverters.defaults(typeConverter)); converters.add(scalarFunctionConverter); converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory))); @@ -121,13 +121,13 @@ public SubstraitRelVisitor( this.featureBoard = features; } - protected Expression toExpression(RexNode node) { + protected Expression toExpression(final RexNode node) { return node.accept(rexExpressionConverter); } @Override - public Rel visit(org.apache.calcite.rel.core.TableScan scan) { - NamedStruct type = typeConverter.toNamedStruct(scan.getRowType()); + public Rel visit(final org.apache.calcite.rel.core.TableScan scan) { + final NamedStruct type = typeConverter.toNamedStruct(scan.getRowType()); return NamedScan.builder() .initialSchema(type) .addAllNames(scan.getTable().getQualifiedName()) @@ -135,23 +135,23 @@ public Rel visit(org.apache.calcite.rel.core.TableScan scan) { } @Override - public Rel visit(org.apache.calcite.rel.core.TableFunctionScan scan) { + public Rel visit(final org.apache.calcite.rel.core.TableFunctionScan scan) { return super.visit(scan); } @Override - public Rel visit(org.apache.calcite.rel.core.Values values) { - NamedStruct type = typeConverter.toNamedStruct(values.getRowType()); + public Rel visit(final org.apache.calcite.rel.core.Values values) { + final NamedStruct type = typeConverter.toNamedStruct(values.getRowType()); if (values.getTuples().isEmpty()) { return EmptyScan.builder().initialSchema(type).build(); } - LiteralConverter literalConverter = new LiteralConverter(typeConverter); - List structs = + final LiteralConverter literalConverter = new LiteralConverter(typeConverter); + final List structs = values.getTuples().stream() .map( list -> { - List fields = + final List fields = list.stream() .map(l -> literalConverter.convert(l)) .collect(Collectors.toUnmodifiableList()); @@ -162,19 +162,19 @@ public Rel visit(org.apache.calcite.rel.core.Values values) { } @Override - public Rel visit(org.apache.calcite.rel.core.Filter filter) { - Expression condition = toExpression(filter.getCondition()); + public Rel visit(final org.apache.calcite.rel.core.Filter filter) { + final Expression condition = toExpression(filter.getCondition()); return Filter.builder().condition(condition).input(apply(filter.getInput())).build(); } @Override - public Rel visit(org.apache.calcite.rel.core.Calc calc) { + public Rel visit(final org.apache.calcite.rel.core.Calc calc) { return super.visit(calc); } @Override - public Rel visit(org.apache.calcite.rel.core.Project project) { - List expressions = + public Rel visit(final org.apache.calcite.rel.core.Project project) { + final List expressions = project.getProjects().stream() .map(this::toExpression) .collect(java.util.stream.Collectors.toList()); @@ -189,11 +189,11 @@ public Rel visit(org.apache.calcite.rel.core.Project project) { } @Override - public Rel visit(org.apache.calcite.rel.core.Join join) { - Rel left = apply(join.getLeft()); - Rel right = apply(join.getRight()); - Expression condition = toExpression(join.getCondition()); - JoinType joinType = asJoinType(join); + public Rel visit(final org.apache.calcite.rel.core.Join join) { + final Rel left = apply(join.getLeft()); + final Rel right = apply(join.getRight()); + final Expression condition = toExpression(join.getCondition()); + final JoinType joinType = asJoinType(join); // An INNER JOIN with a join condition of TRUE can be encoded as a Substrait Cross relation if (joinType == Join.JoinType.INNER && TRUE.equals(condition)) { @@ -202,8 +202,8 @@ public Rel visit(org.apache.calcite.rel.core.Join join) { return Join.builder().condition(condition).joinType(joinType).left(left).right(right).build(); } - private Join.JoinType asJoinType(org.apache.calcite.rel.core.Join join) { - JoinRelType type = join.getJoinType(); + private Join.JoinType asJoinType(final org.apache.calcite.rel.core.Join join) { + final JoinRelType type = join.getJoinType(); if (type == JoinRelType.INNER) { return Join.JoinType.INNER; @@ -223,7 +223,7 @@ private Join.JoinType asJoinType(org.apache.calcite.rel.core.Join join) { } @Override - public Rel visit(org.apache.calcite.rel.core.Correlate correlate) { + public Rel visit(final org.apache.calcite.rel.core.Correlate correlate) { // left input of correlated-join is similar to the left input of a logical join apply(correlate.getLeft()); @@ -234,64 +234,64 @@ public Rel visit(org.apache.calcite.rel.core.Correlate correlate) { } @Override - public Rel visit(org.apache.calcite.rel.core.Union union) { - List inputs = apply(union.getInputs()); - Set.SetOp setOp = union.all ? Set.SetOp.UNION_ALL : Set.SetOp.UNION_DISTINCT; + public Rel visit(final org.apache.calcite.rel.core.Union union) { + final List inputs = apply(union.getInputs()); + final Set.SetOp setOp = union.all ? Set.SetOp.UNION_ALL : Set.SetOp.UNION_DISTINCT; return Set.builder().inputs(inputs).setOp(setOp).build(); } @Override - public Rel visit(org.apache.calcite.rel.core.Intersect intersect) { - List inputs = apply(intersect.getInputs()); - Set.SetOp setOp = + public Rel visit(final org.apache.calcite.rel.core.Intersect intersect) { + final List inputs = apply(intersect.getInputs()); + final Set.SetOp setOp = intersect.all ? Set.SetOp.INTERSECTION_MULTISET_ALL : Set.SetOp.INTERSECTION_MULTISET; return Set.builder().inputs(inputs).setOp(setOp).build(); } @Override - public Rel visit(org.apache.calcite.rel.core.Minus minus) { - List inputs = apply(minus.getInputs()); - Set.SetOp setOp = minus.all ? Set.SetOp.MINUS_PRIMARY_ALL : Set.SetOp.MINUS_PRIMARY; + public Rel visit(final org.apache.calcite.rel.core.Minus minus) { + final List inputs = apply(minus.getInputs()); + final Set.SetOp setOp = minus.all ? Set.SetOp.MINUS_PRIMARY_ALL : Set.SetOp.MINUS_PRIMARY; return Set.builder().inputs(inputs).setOp(setOp).build(); } @Override - public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { - Rel input = apply(aggregate.getInput()); - Stream sets; + public Rel visit(final org.apache.calcite.rel.core.Aggregate aggregate) { + final Rel input = apply(aggregate.getInput()); + final Stream sets; if (aggregate.groupSets != null) { sets = aggregate.groupSets.stream(); } else { sets = Stream.of(aggregate.getGroupSet()); } - List groupings = + final List groupings = sets.filter(s -> s != null).map(s -> fromGroupSet(s, input)).collect(Collectors.toList()); // get GROUP_ID() function calls - List groupIdCalls = + final List groupIdCalls = aggregate.getAggCallList().stream() .filter(c -> c.getAggregation().equals(SqlStdOperatorTable.GROUP_ID)) .collect(Collectors.toList()); - List filteredAggCalls = + final List filteredAggCalls = aggregate.getAggCallList().stream() // remove GROUP_ID() function calls .filter(c -> !groupIdCalls.contains(c)) .collect(Collectors.toList()); - List aggCalls = + final List aggCalls = filteredAggCalls.stream() .map(c -> fromAggCall(aggregate.getInput(), input.getRecordType(), c)) .collect(Collectors.toList()); - ImmutableAggregate.Builder builder = + final ImmutableAggregate.Builder builder = Aggregate.builder().input(input).addAllGroupings(groupings).addAllMeasures(aggCalls); if (groupings.size() > 1) { // remove the grouping set index if there was no explicit GROUP_ID() function call if (groupIdCalls.isEmpty()) { - int groupingExprSize = + final int groupingExprSize = Math.toIntExact( groupings.stream().flatMap(g -> g.getExpressions().stream()).distinct().count()); builder.remap(Remap.offset(0, groupingExprSize + aggCalls.size())); @@ -308,7 +308,7 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { .collect(Collectors.toCollection(ArrayList::new)); for (int i = 0; i < aggregate.getAggCallList().size(); i++) { - AggregateCall aggCall = aggregate.getAggCallList().get(i); + final AggregateCall aggCall = aggregate.getAggCallList().get(i); if (filteredAggCalls.contains(aggCall)) { remap.add( i + groupingFieldCount, filteredAggCalls.indexOf(aggCall) + groupingFieldCount); @@ -329,22 +329,23 @@ public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { return builder.build(); } - Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) { - List references = + Aggregate.Grouping fromGroupSet(final ImmutableBitSet bitSet, final Rel input) { + final List references = bitSet.asList().stream() .map(i -> FieldReference.newInputRelReference(i, input)) .collect(Collectors.toList()); return Aggregate.Grouping.builder().addAllExpressions(references).build(); } - Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCall call) { - Optional invocation = + Aggregate.Measure fromAggCall( + final RelNode input, final Type.Struct inputType, final AggregateCall call) { + final Optional invocation = aggregateFunctionConverter.convert( input, inputType, call, t -> t.accept(rexExpressionConverter)); if (invocation.isEmpty()) { throw new UnsupportedOperationException("Unable to find binding for call " + call); } - Builder builder = Aggregate.Measure.builder().function(invocation.get()); + final Builder builder = Aggregate.Measure.builder().function(invocation.get()); if (call.filterArg != -1) { builder.preMeasureFilter(FieldReference.newRootStructReference(call.filterArg, inputType)); } @@ -354,13 +355,13 @@ Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCal } @Override - public Rel visit(org.apache.calcite.rel.core.Match match) { + public Rel visit(final org.apache.calcite.rel.core.Match match) { return super.visit(match); } @Override - public Rel visit(org.apache.calcite.rel.core.Sort sort) { - Rel input = apply(sort.getInput()); + public Rel visit(final org.apache.calcite.rel.core.Sort sort) { + final Rel input = apply(sort.getInput()); Rel output = input; // The Calcite Sort relation combines sorting along with offset and fetch/limit @@ -368,7 +369,7 @@ public Rel visit(org.apache.calcite.rel.core.Sort sort) { // Substrait splits this functionality into two different relations: SortRel, FetchRel // Add the SortRel to the relation tree first to match Calcite's application order if (!sort.getCollation().getFieldCollations().isEmpty()) { - List fields = + final List fields = sort.getCollation().getFieldCollations().stream() .map(t -> toSortField(t, input.getRecordType())) .collect(java.util.stream.Collectors.toList()); @@ -376,21 +377,22 @@ public Rel visit(org.apache.calcite.rel.core.Sort sort) { } if (sort.fetch != null || sort.offset != null) { - Long offset = Optional.ofNullable(sort.offset).map(this::asLong).orElse(0L); - OptionalLong count = + final Long offset = Optional.ofNullable(sort.offset).map(this::asLong).orElse(0L); + final OptionalLong count = Optional.ofNullable(sort.fetch) .map(r -> OptionalLong.of(asLong(r))) .orElse(OptionalLong.empty()); - ImmutableFetch.Builder builder = Fetch.builder().input(output).offset(offset).count(count); + final ImmutableFetch.Builder builder = + Fetch.builder().input(output).offset(offset).count(count); output = builder.build(); } return output; } - private long asLong(RexNode rex) { - Expression expr = toExpression(rex); + private long asLong(final RexNode rex) { + final Expression expr = toExpression(rex); if (expr instanceof Expression.I64Literal) { return ((Expression.I64Literal) expr).value(); } else if (expr instanceof Expression.I32Literal) { @@ -400,8 +402,8 @@ private long asLong(RexNode rex) { } public static Expression.SortField toSortField( - RelFieldCollation collation, Type.Struct inputType) { - Expression.SortDirection direction = asSortDirection(collation); + final RelFieldCollation collation, final Type.Struct inputType) { + final Expression.SortDirection direction = asSortDirection(collation); return Expression.SortField.builder() .expr(FieldReference.newRootStructReference(collation.getFieldIndex(), inputType)) @@ -409,8 +411,8 @@ public static Expression.SortField toSortField( .build(); } - private static Expression.SortDirection asSortDirection(RelFieldCollation collation) { - RelFieldCollation.Direction direction = collation.direction; + private static Expression.SortDirection asSortDirection(final RelFieldCollation collation) { + final RelFieldCollation.Direction direction = collation.direction; if (direction == Direction.STRICTLY_ASCENDING || direction == Direction.ASCENDING) { return collation.nullDirection == RelFieldCollation.NullDirection.LAST @@ -428,12 +430,12 @@ private static Expression.SortDirection asSortDirection(RelFieldCollation collat } @Override - public Rel visit(org.apache.calcite.rel.core.Exchange exchange) { + public Rel visit(final org.apache.calcite.rel.core.Exchange exchange) { return super.visit(exchange); } @Override - public Rel visit(TableModify modify) { + public Rel visit(final TableModify modify) { switch (modify.getOperation()) { case INSERT: case DELETE: @@ -459,31 +461,32 @@ public Rel visit(TableModify modify) { { assert modify.getTable() != null; - RelNode input = modify.getInput(); + final RelNode input = modify.getInput(); final Expression condition; if (input instanceof org.apache.calcite.rel.core.Filter) { - org.apache.calcite.rel.core.Filter filter = (org.apache.calcite.rel.core.Filter) input; + final org.apache.calcite.rel.core.Filter filter = + (org.apache.calcite.rel.core.Filter) input; condition = toExpression(filter.getCondition()); } else { condition = Expression.BoolLiteral.builder().nullable(false).value(true).build(); } - List updateColumnNames = modify.getUpdateColumnList(); - List sourceExpressions = getSourceExpressions(modify); - List allTableColumnNames = modify.getTable().getRowType().getFieldNames(); - List transformations = new ArrayList<>(); + final List updateColumnNames = modify.getUpdateColumnList(); + final List sourceExpressions = getSourceExpressions(modify); + final List allTableColumnNames = modify.getTable().getRowType().getFieldNames(); + final List transformations = new ArrayList<>(); for (int i = 0; i < updateColumnNames.size(); i++) { - String colName = updateColumnNames.get(i); - RexNode rexExpr = sourceExpressions.get(i); + final String colName = updateColumnNames.get(i); + final RexNode rexExpr = sourceExpressions.get(i); - int columnIndex = allTableColumnNames.indexOf(colName); + final int columnIndex = allTableColumnNames.indexOf(colName); if (columnIndex == -1) { throw new IllegalStateException( "Updated column '" + colName + "' not found in table schema."); } - Expression substraitExpr = toExpression(rexExpr); + final Expression substraitExpr = toExpression(rexExpr); transformations.add( NamedUpdate.TransformExpression.builder() @@ -505,13 +508,13 @@ public Rel visit(TableModify modify) { } } - private List getSourceExpressions(TableModify modify) { - List results = modify.getSourceExpressionList(); + private List getSourceExpressions(final TableModify modify) { + final List results = modify.getSourceExpressionList(); if (results == null) { return Collections.emptyList(); } - RelNode input = modify.getInput(); + final RelNode input = modify.getInput(); if (input instanceof org.apache.calcite.rel.core.Project) { return resolveProjectedRefs(results, (org.apache.calcite.rel.core.Project) input); } @@ -520,13 +523,13 @@ private List getSourceExpressions(TableModify modify) { } private List resolveProjectedRefs( - List expressions, org.apache.calcite.rel.core.Project project) { - List projects = project.getProjects(); + final List expressions, final org.apache.calcite.rel.core.Project project) { + final List projects = project.getProjects(); return expressions.stream() .map( expression -> { if (expression instanceof RexInputRef) { - int refIndex = ((RexInputRef) expression).getIndex(); + final int refIndex = ((RexInputRef) expression).getIndex(); return projects.get(refIndex); } @@ -540,10 +543,10 @@ private NamedStruct getSchema(final RelNode queryRelRoot) { return typeConverter.toNamedStruct(rowType); } - public Rel handleCreateTable(CreateTable createTable) { - RelNode input = createTable.getInput(); - Rel inputRel = apply(input); - NamedStruct schema = getSchema(input); + public Rel handleCreateTable(final CreateTable createTable) { + final RelNode input = createTable.getInput(); + final Rel inputRel = apply(input); + final NamedStruct schema = getSchema(input); return NamedWrite.builder() .input(inputRel) .tableSchema(schema) @@ -554,9 +557,9 @@ public Rel handleCreateTable(CreateTable createTable) { .build(); } - public Rel handleCreateView(CreateView createView) { - RelNode input = createView.getInput(); - Rel inputRel = apply(input); + public Rel handleCreateView(final CreateView createView) { + final RelNode input = createView.getInput(); + final Rel inputRel = apply(input); final Expression.StructLiteral defaults = ExpressionCreator.struct(false); @@ -571,7 +574,7 @@ public Rel handleCreateView(CreateView createView) { } @Override - public Rel visitOther(RelNode other) { + public Rel visitOther(final RelNode other) { if (other instanceof CreateTable) { return handleCreateTable((CreateTable) other); @@ -582,21 +585,21 @@ public Rel visitOther(RelNode other) { throw new UnsupportedOperationException("Unable to handle node: " + other); } - protected void popFieldAccessDepthMap(RelNode root) { + protected void popFieldAccessDepthMap(final RelNode root) { final OuterReferenceResolver resolver = new OuterReferenceResolver(); resolver.apply(root); fieldAccessDepthMap = resolver.getFieldAccessDepthMap(); } - public Integer getFieldAccessDepth(RexFieldAccess fieldAccess) { + public Integer getFieldAccessDepth(final RexFieldAccess fieldAccess) { return fieldAccessDepthMap.get(fieldAccess); } - public Rel apply(RelNode r) { + public Rel apply(final RelNode r) { return reverseAccept(r); } - public List apply(List inputs) { + public List apply(final List inputs) { return inputs.stream() .map(inputRel -> apply(inputRel)) .collect(java.util.stream.Collectors.toList()); @@ -612,7 +615,8 @@ public List apply(List inputs) { * @param extensions The extension collection to use for the conversion. * @return The resulting Substrait Plan.Root. */ - public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollection extensions) { + public static Plan.Root convert( + final RelRoot relRoot, final SimpleExtension.ExtensionCollection extensions) { return convert(relRoot, extensions, FEATURES_DEFAULT); } @@ -634,13 +638,14 @@ public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollec * @return The resulting Substrait Plan.Root, containing the converted relational tree and the * output names. */ - public static Plan.Root convert(RelRoot relRoot, SubstraitRelVisitor visitor) { + public static Plan.Root convert(final RelRoot relRoot, final SubstraitRelVisitor visitor) { visitor.popFieldAccessDepthMap(relRoot.rel); - Rel rel = visitor.apply(relRoot.project()); + final Rel rel = visitor.apply(relRoot.project()); // Avoid using the names from relRoot.validatedRowType because if there are // nested types (i.e ROW, MAP, etc) the typeConverter will pad names correctly - List names = visitor.typeConverter.toNamedStruct(relRoot.validatedRowType).names(); + final List names = + visitor.typeConverter.toNamedStruct(relRoot.validatedRowType).names(); return Plan.Root.builder().input(rel).names(names).build(); } @@ -657,7 +662,9 @@ public static Plan.Root convert(RelRoot relRoot, SubstraitRelVisitor visitor) { * @return The resulting Substrait Plan.Root. */ public static Plan.Root convert( - RelRoot relRoot, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { + final RelRoot relRoot, + final SimpleExtension.ExtensionCollection extensions, + final FeatureBoard features) { return convert( relRoot, new SubstraitRelVisitor(relRoot.rel.getCluster().getTypeFactory(), extensions, features)); @@ -677,7 +684,8 @@ public static Plan.Root convert( * @param extensions The extension collection to use for the conversion. * @return The resulting Substrait Rel. */ - public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection extensions) { + public static Rel convert( + final RelNode relNode, final SimpleExtension.ExtensionCollection extensions) { return convert(relNode, extensions, FEATURES_DEFAULT); } @@ -695,7 +703,7 @@ public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection e * behavior. * @return The resulting Substrait Rel. */ - public static Rel convert(RelNode relNode, SubstraitRelVisitor visitor) { + public static Rel convert(final RelNode relNode, final SubstraitRelVisitor visitor) { visitor.popFieldAccessDepthMap(relNode); return visitor.apply(relNode); } @@ -716,7 +724,9 @@ public static Rel convert(RelNode relNode, SubstraitRelVisitor visitor) { * @return The resulting Substrait Rel. */ public static Rel convert( - RelNode relNode, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { + final RelNode relNode, + final SimpleExtension.ExtensionCollection extensions, + final FeatureBoard features) { return convert( relNode, new SubstraitRelVisitor(relNode.getCluster().getTypeFactory(), extensions, features)); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java index 8dcfbf9e0..7073437ba 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java @@ -40,29 +40,29 @@ public class SubstraitToCalcite { protected final Prepare.CatalogReader catalogReader; public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) { + final SimpleExtension.ExtensionCollection extensions, final RelDataTypeFactory typeFactory) { this(extensions, typeFactory, TypeConverter.DEFAULT, null); } public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - Prepare.CatalogReader catalogReader) { + final SimpleExtension.ExtensionCollection extensions, + final RelDataTypeFactory typeFactory, + final Prepare.CatalogReader catalogReader) { this(extensions, typeFactory, TypeConverter.DEFAULT, catalogReader); } public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter) { + final SimpleExtension.ExtensionCollection extensions, + final RelDataTypeFactory typeFactory, + final TypeConverter typeConverter) { this(extensions, typeFactory, typeConverter, null); } public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter, - Prepare.CatalogReader catalogReader) { + final SimpleExtension.ExtensionCollection extensions, + final RelDataTypeFactory typeFactory, + final TypeConverter typeConverter, + final Prepare.CatalogReader catalogReader) { this.extensions = extensions; this.typeFactory = typeFactory; this.typeConverter = typeConverter; @@ -74,8 +74,8 @@ public SubstraitToCalcite( * *

    Override this method to customize schema extraction. */ - protected CalciteSchema toSchema(Rel rel) { - SchemaCollector schemaCollector = new SchemaCollector(typeFactory, typeConverter); + protected CalciteSchema toSchema(final Rel rel) { + final SchemaCollector schemaCollector = new SchemaCollector(typeFactory, typeConverter); return schemaCollector.toSchema(rel); } @@ -84,7 +84,7 @@ protected CalciteSchema toSchema(Rel rel) { * *

    Override this method to customize the {@link RelBuilder}. */ - protected RelBuilder createRelBuilder(CalciteSchema schema) { + protected RelBuilder createRelBuilder(final CalciteSchema schema) { return RelBuilder.create(Frameworks.newConfigBuilder().defaultSchema(schema.plus()).build()); } @@ -93,7 +93,7 @@ protected RelBuilder createRelBuilder(CalciteSchema schema) { * *

    Override this method to customize the {@link SubstraitRelNodeConverter}. */ - protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { + protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(final RelBuilder relBuilder) { return new SubstraitRelNodeConverter(extensions, typeFactory, relBuilder); } @@ -107,15 +107,15 @@ protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder r * @param rel {@link Rel} to convert * @return {@link RelNode} */ - public RelNode convert(Rel rel) { - RelBuilder relBuilder; + public RelNode convert(final Rel rel) { + final RelBuilder relBuilder; if (catalogReader != null) { relBuilder = createRelBuilder(catalogReader.getRootSchema()); } else { - CalciteSchema rootSchema = toSchema(rel); + final CalciteSchema rootSchema = toSchema(rel); relBuilder = createRelBuilder(rootSchema); } - SubstraitRelNodeConverter converter = createSubstraitRelNodeConverter(relBuilder); + final SubstraitRelNodeConverter converter = createSubstraitRelNodeConverter(relBuilder); return rel.accept(converter, Context.newContext()); } @@ -134,8 +134,8 @@ public RelNode convert(Rel rel) { * @param root {@link Plan.Root} to convert * @return {@link RelRoot} */ - public RelRoot convert(Plan.Root root) { - RelNode convertedNode = convert(root.getInput()); + public RelRoot convert(final Plan.Root root) { + final RelNode convertedNode = convert(root.getInput()); if (convertedNode instanceof TableModify) { final TableModify tableModify = (TableModify) convertedNode; @@ -158,10 +158,10 @@ public RelRoot convert(Plan.Root root) { } return RelRoot.of(tableModify, tableRowType, kind); } - SqlKindFromRel sqlKindFromRel = new SqlKindFromRel(); - SqlKind kind = root.getInput().accept(sqlKindFromRel, EmptyVisitationContext.INSTANCE); - RelDataType inputRowType = convertedNode.getRowType(); - RelDataType newRowType = renameFields(inputRowType, root.getNames(), 0).right; + final SqlKindFromRel sqlKindFromRel = new SqlKindFromRel(); + final SqlKind kind = root.getInput().accept(sqlKindFromRel, EmptyVisitationContext.INSTANCE); + final RelDataType inputRowType = convertedNode.getRowType(); + final RelDataType newRowType = renameFields(inputRowType, root.getNames(), 0).right; return RelRoot.of(convertedNode, newRowType, kind); } @@ -175,7 +175,7 @@ public RelRoot convert(Plan.Root root) { * @return the renamed {@link RelDataType} */ private Pair renameFields( - RelDataType type, List names, Integer currentIndex) { + final RelDataType type, final List names, final Integer currentIndex) { Integer nextIndex = currentIndex; switch (type.getSqlTypeName()) { @@ -183,9 +183,10 @@ private Pair renameFields( case STRUCTURED: final List newFieldNames = new ArrayList<>(); final List renamedFields = new ArrayList<>(); - for (RelDataTypeField field : type.getFieldList()) { + for (final RelDataTypeField field : type.getFieldList()) { newFieldNames.add(names.get(nextIndex)); - Pair p = renameFields(field.getType(), names, (nextIndex + 1)); + final Pair p = + renameFields(field.getType(), names, (nextIndex + 1)); renamedFields.add(p.right); nextIndex = p.left; } @@ -195,15 +196,15 @@ private Pair renameFields( typeFactory.createStructType(type.getStructKind(), renamedFields, newFieldNames)); case ARRAY: case MULTISET: - Pair renamedElementType = + final Pair renamedElementType = renameFields(type.getComponentType(), names, nextIndex); return Pair.of( renamedElementType.left, typeFactory.createArrayType(renamedElementType.right, -1L)); case MAP: - Pair renamedKeyType = + final Pair renamedKeyType = renameFields(type.getKeyType(), names, nextIndex); - Pair renamedValueType = + final Pair renamedValueType = renameFields(type.getValueType(), names, renamedKeyType.left); return Pair.of( @@ -222,17 +223,17 @@ private NamedStructGatherer() { this.tableMap = new HashMap<>(); } - public static Map, NamedStruct> gatherTables(Rel rel) { - NamedStructGatherer visitor = new NamedStructGatherer(); + public static Map, NamedStruct> gatherTables(final Rel rel) { + final NamedStructGatherer visitor = new NamedStructGatherer(); rel.accept(visitor, EmptyVisitationContext.INSTANCE); return visitor.tableMap; } @Override - public Optional visit(NamedScan namedScan, EmptyVisitationContext context) { - Optional result = super.visit(namedScan, context); + public Optional visit(final NamedScan namedScan, final EmptyVisitationContext context) { + final Optional result = super.visit(namedScan, context); - List tableName = namedScan.getNames(); + final List tableName = namedScan.getNames(); tableMap.put(tableName, namedScan.getInitialSchema()); return result; diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java index 421b45317..8cc720246 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java @@ -10,7 +10,7 @@ public SubstraitToSql() { super(FEATURES_DEFAULT); } - public RelNode substraitRelToCalciteRel(Rel relRoot, Prepare.CatalogReader catalog) { + public RelNode substraitRelToCalciteRel(final Rel relRoot, final Prepare.CatalogReader catalog) { return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, catalog, parserConfig); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java index 932b8f6d8..6f9904d12 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/TypeConverter.java @@ -30,43 +30,43 @@ public class TypeConverter { new UserTypeMapper() { @Nullable @Override - public Type toSubstrait(RelDataType relDataType) { + public Type toSubstrait(final RelDataType relDataType) { return null; } @Nullable @Override - public RelDataType toCalcite(Type.UserDefined type) { + public RelDataType toCalcite(final Type.UserDefined type) { return null; } }); - public TypeConverter(UserTypeMapper userTypeMapper) { + public TypeConverter(final UserTypeMapper userTypeMapper) { this.userTypeMapper = userTypeMapper; } - public Type toSubstrait(RelDataType type) { + public Type toSubstrait(final RelDataType type) { return toSubstrait(type, new ArrayList<>()); } - public NamedStruct toNamedStruct(RelDataType type) { + public NamedStruct toNamedStruct(final RelDataType type) { if (type.getSqlTypeName() != SqlTypeName.ROW) { throw new IllegalArgumentException("Expected type of struct."); } - ArrayList names = new ArrayList(); - Struct struct = (Type.Struct) toSubstrait(type, names); + final ArrayList names = new ArrayList(); + final Struct struct = (Type.Struct) toSubstrait(type, names); return NamedStruct.of(names, struct); } - private Type toSubstrait(RelDataType type, List names) { + private Type toSubstrait(final RelDataType type, final List names) { // Check for user mapped types first as they may re-use SqlTypeNames - Type userType = userTypeMapper.toSubstrait(type); + final Type userType = userTypeMapper.toSubstrait(type); if (userType != null) { return userType; } - TypeCreator creator = Type.withNullability(type.isNullable()); + final TypeCreator creator = Type.withNullability(type.isNullable()); switch (type.getSqlTypeName()) { case BOOLEAN: @@ -132,14 +132,14 @@ private Type toSubstrait(RelDataType type, List names) { return creator.fixedBinary(type.getPrecision()); case MAP: { - MapSqlType map = (MapSqlType) type; + final MapSqlType map = (MapSqlType) type; return creator.map( toSubstrait(map.getKeyType(), names), toSubstrait(map.getValueType(), names)); } case ROW: { - ArrayList children = new ArrayList(); - for (RelDataTypeField field : type.getFieldList()) { + final ArrayList children = new ArrayList(); + for (final RelDataTypeField field : type.getFieldList()) { names.add(field.getName()); children.add(toSubstrait(field.getType(), names)); } @@ -154,14 +154,14 @@ private Type toSubstrait(RelDataType type, List names) { } public RelDataType toCalcite( - RelDataTypeFactory relDataTypeFactory, TypeExpression typeExpression) { + final RelDataTypeFactory relDataTypeFactory, final TypeExpression typeExpression) { return toCalcite(relDataTypeFactory, typeExpression, null); } public RelDataType toCalcite( - RelDataTypeFactory relDataTypeFactory, - TypeExpression typeExpression, - List dfsFieldNames) { + final RelDataTypeFactory relDataTypeFactory, + final TypeExpression typeExpression, + final List dfsFieldNames) { return typeExpression.accept( new ToRelDataType(relDataTypeFactory, userTypeMapper, dfsFieldNames, 0)); } @@ -179,7 +179,7 @@ public ToRelDataType( final RelDataTypeFactory type, final UserTypeMapper userTypeMapper, final List fieldNames, - int fieldNamePosition) { + final int fieldNamePosition) { super("Unknown expression type."); this.typeFactory = type; this.userTypeMapper = userTypeMapper; @@ -188,73 +188,73 @@ public ToRelDataType( } @Override - public RelDataType visit(Type.Bool expr) { + public RelDataType visit(final Type.Bool expr) { return t(n(expr), SqlTypeName.BOOLEAN); } @Override - public RelDataType visit(Type.I8 expr) { + public RelDataType visit(final Type.I8 expr) { return t(n(expr), SqlTypeName.TINYINT); } @Override - public RelDataType visit(Type.I16 expr) { + public RelDataType visit(final Type.I16 expr) { return t(n(expr), SqlTypeName.SMALLINT); } @Override - public RelDataType visit(Type.I32 expr) { + public RelDataType visit(final Type.I32 expr) { return t(n(expr), SqlTypeName.INTEGER); } @Override - public RelDataType visit(Type.I64 expr) { + public RelDataType visit(final Type.I64 expr) { return t(n(expr), SqlTypeName.BIGINT); } @Override - public RelDataType visit(Type.FP32 expr) { + public RelDataType visit(final Type.FP32 expr) { return t(n(expr), SqlTypeName.REAL); } @Override - public RelDataType visit(Type.FP64 expr) { + public RelDataType visit(final Type.FP64 expr) { return t(n(expr), SqlTypeName.DOUBLE); } @Override - public RelDataType visit(Type.Str expr) { + public RelDataType visit(final Type.Str expr) { return t(n(expr), SqlTypeName.VARCHAR); } @Override - public RelDataType visit(Type.Binary expr) { + public RelDataType visit(final Type.Binary expr) { return t(n(expr), SqlTypeName.VARBINARY); } @Override - public RelDataType visit(Type.Date expr) { + public RelDataType visit(final Type.Date expr) { return t(n(expr), SqlTypeName.DATE); } @Override - public RelDataType visit(Type.Time expr) { + public RelDataType visit(final Type.Time expr) { return t(n(expr), SqlTypeName.TIME, 6); } @Override - public RelDataType visit(Type.TimestampTZ expr) { + public RelDataType visit(final Type.TimestampTZ expr) { return t(n(expr), SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE, 6); } @Override - public RelDataType visit(Type.Timestamp expr) { + public RelDataType visit(final Type.Timestamp expr) { return t(n(expr), SqlTypeName.TIMESTAMP, 6); } @Override - public RelDataType visit(Type.PrecisionTime expr) { - int maxPrecision = typeFactory.getTypeSystem().getMaxPrecision(SqlTypeName.TIME); + public RelDataType visit(final Type.PrecisionTime expr) { + final int maxPrecision = typeFactory.getTypeSystem().getMaxPrecision(SqlTypeName.TIME); if (expr.precision() > maxPrecision) { throw new UnsupportedOperationException( String.format( @@ -265,8 +265,8 @@ public RelDataType visit(Type.PrecisionTime expr) { } @Override - public RelDataType visit(Type.PrecisionTimestamp expr) { - int maxPrecision = typeFactory.getTypeSystem().getMaxPrecision(SqlTypeName.TIMESTAMP); + public RelDataType visit(final Type.PrecisionTimestamp expr) { + final int maxPrecision = typeFactory.getTypeSystem().getMaxPrecision(SqlTypeName.TIMESTAMP); if (expr.precision() > maxPrecision) { throw new UnsupportedOperationException( String.format( @@ -277,8 +277,8 @@ public RelDataType visit(Type.PrecisionTimestamp expr) { } @Override - public RelDataType visit(Type.PrecisionTimestampTZ expr) throws RuntimeException { - int maxPrecision = + public RelDataType visit(final Type.PrecisionTimestampTZ expr) throws RuntimeException { + final int maxPrecision = typeFactory.getTypeSystem().getMaxPrecision(SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE); if (expr.precision() > maxPrecision) { throw new UnsupportedOperationException( @@ -290,51 +290,51 @@ public RelDataType visit(Type.PrecisionTimestampTZ expr) throws RuntimeException } @Override - public RelDataType visit(Type.IntervalYear expr) { + public RelDataType visit(final Type.IntervalYear expr) { return typeFactory.createTypeWithNullability( typeFactory.createSqlIntervalType(YEAR_MONTH_INTERVAL), n(expr)); } @Override - public RelDataType visit(Type.IntervalDay expr) { + public RelDataType visit(final Type.IntervalDay expr) { return typeFactory.createTypeWithNullability( typeFactory.createSqlIntervalType(DAY_SECOND_INTERVAL), n(expr)); } @Override - public RelDataType visit(Type.FixedChar expr) { + public RelDataType visit(final Type.FixedChar expr) { return t(n(expr), SqlTypeName.CHAR, expr.length()); } @Override - public RelDataType visit(Type.VarChar expr) { + public RelDataType visit(final Type.VarChar expr) { return t(n(expr), SqlTypeName.VARCHAR, expr.length()); } @Override - public RelDataType visit(Type.FixedBinary expr) { + public RelDataType visit(final Type.FixedBinary expr) { return t(n(expr), SqlTypeName.BINARY, expr.length()); } @Override - public RelDataType visit(Type.Decimal expr) { + public RelDataType visit(final Type.Decimal expr) { return t(n(expr), SqlTypeName.DECIMAL, expr.precision(), expr.scale()); } @Override - public RelDataType visit(Type.Struct expr) { + public RelDataType visit(final Type.Struct expr) { if (withinStruct) { throw new IllegalStateException("Visitor can't be re-used for nested structs."); } withinStruct = true; try { - List fieldTypes = new ArrayList<>(); - List localFieldNames = new ArrayList<>(); - for (TypeExpression field : expr.fields()) { + final List fieldTypes = new ArrayList<>(); + final List localFieldNames = new ArrayList<>(); + for (final TypeExpression field : expr.fields()) { localFieldNames.add( fieldNames == null ? "f" + fieldNamePosition : fieldNames.get(fieldNamePosition)); fieldNamePosition++; - ToRelDataType childVisitor = + final ToRelDataType childVisitor = new ToRelDataType(typeFactory, userTypeMapper, fieldNames, fieldNamePosition); fieldTypes.add(field.accept(childVisitor)); fieldNamePosition = childVisitor.fieldNamePosition; @@ -348,18 +348,18 @@ public RelDataType visit(Type.Struct expr) { } @Override - public RelDataType visit(Type.ListType expr) { + public RelDataType visit(final Type.ListType expr) { return n(expr, typeFactory.createArrayType(expr.elementType().accept(this), -1)); } @Override - public RelDataType visit(Type.Map expr) { + public RelDataType visit(final Type.Map expr) { return n(expr, typeFactory.createMapType(expr.key().accept(this), expr.value().accept(this))); } @Override - public RelDataType visit(Type.UserDefined expr) throws RuntimeException { - RelDataType type = userTypeMapper.toCalcite(expr); + public RelDataType visit(final Type.UserDefined expr) throws RuntimeException { + final RelDataType type = userTypeMapper.toCalcite(expr); if (type != null) { return type; } @@ -367,11 +367,12 @@ public RelDataType visit(Type.UserDefined expr) throws RuntimeException { String.format("Unable to map user-defined type: %s", expr)); } - private boolean n(NullableType type) { + private boolean n(final NullableType type) { return type.nullable(); } - private RelDataType t(boolean nullable, SqlTypeName typeName, Integer... props) { + private RelDataType t( + final boolean nullable, final SqlTypeName typeName, final Integer... props) { final RelDataType baseType; if (props.length == 0) { baseType = typeFactory.createSqlType(typeName); @@ -387,7 +388,7 @@ private RelDataType t(boolean nullable, SqlTypeName typeName, Integer... props) return typeFactory.createTypeWithNullability(baseType, nullable); } - private RelDataType n(Type substraitType, RelDataType type) { + private RelDataType n(final Type substraitType, final RelDataType type) { return typeFactory.createTypeWithNullability(type, n(substraitType)); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/Utils.java b/isthmus/src/main/java/io/substrait/isthmus/Utils.java index 3382007f1..fc8796da3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/Utils.java +++ b/isthmus/src/main/java/io/substrait/isthmus/Utils.java @@ -27,10 +27,10 @@ public static Stream> crossProduct(List> lists) { * @param element 1 * @return [a, b, 1] */ - BiFunction, T, List> appendElementToList = + final BiFunction, T, List> appendElementToList = (list, element) -> { - int capacity = list.size() + 1; - ArrayList newList = new ArrayList<>(capacity); + final int capacity = list.size() + 1; + final ArrayList newList = new ArrayList<>(capacity); newList.addAll(list); newList.add(element); return unmodifiableList(newList); @@ -39,12 +39,12 @@ public static Stream> crossProduct(List> lists) { /* * ([a, b], [1, 2]) -> [a, b, 1], [a, b, 2] */ - BiFunction, List, Stream>> appendAndGen = + final BiFunction, List, Stream>> appendAndGen = (list, elemsToAppend) -> elemsToAppend.stream().map(element -> appendElementToList.apply(list, element)); /** ([[a, b], [c, d]], [1, 2]) -> [a, b, 1], [a, b, 2], [c, d, 1], [c, d, 2] */ - BiFunction>, List, Stream>> appendAndGenLists = + final BiFunction>, List, Stream>> appendAndGenLists = (products, toJoin) -> products.flatMap(product -> appendAndGen.apply(product, toJoin)); if (lists.isEmpty()) { @@ -52,8 +52,8 @@ public static Stream> crossProduct(List> lists) { } lists = new ArrayList<>(lists); - List firstListToJoin = lists.remove(0); - Stream> startProduct = appendAndGen.apply(new ArrayList(), firstListToJoin); + final List firstListToJoin = lists.remove(0); + final Stream> startProduct = appendAndGen.apply(new ArrayList(), firstListToJoin); return lists.stream() // .filter(Objects::nonNull) // diff --git a/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitOperatorTable.java b/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitOperatorTable.java index 259ae5b1c..ea922b71d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitOperatorTable.java +++ b/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitOperatorTable.java @@ -71,11 +71,11 @@ private SubstraitOperatorTable() {} @Override public void lookupOperatorOverloads( - SqlIdentifier opName, - @Nullable SqlFunctionCategory category, - SqlSyntax syntax, - List operatorList, - SqlNameMatcher nameMatcher) { + final SqlIdentifier opName, + @Nullable final SqlFunctionCategory category, + final SqlSyntax syntax, + final List operatorList, + final SqlNameMatcher nameMatcher) { SUBSTRAIT_OPERATOR_TABLE.lookupOperatorOverloads( opName, category, syntax, operatorList, nameMatcher); if (!operatorList.isEmpty()) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitSchema.java b/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitSchema.java index 8530fe661..37b767fea 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitSchema.java +++ b/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitSchema.java @@ -20,7 +20,7 @@ public SubstraitSchema() { this.schemaMap = new HashMap<>(); } - public SubstraitSchema(Map tableMap) { + public SubstraitSchema(final Map tableMap) { this.tableMap = tableMap; this.schemaMap = new HashMap<>(); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitTable.java b/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitTable.java index f642c73d8..8f0067932 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitTable.java +++ b/isthmus/src/main/java/io/substrait/isthmus/calcite/SubstraitTable.java @@ -10,7 +10,7 @@ public class SubstraitTable extends AbstractTable { private final RelDataType rowType; private final String tableName; - public SubstraitTable(String tableName, RelDataType rowType) { + public SubstraitTable(final String tableName, final RelDataType rowType) { this.tableName = tableName; this.rowType = rowType; } @@ -20,7 +20,7 @@ public String getName() { } @Override - public RelDataType getRowType(RelDataTypeFactory typeFactory) { + public RelDataType getRowType(final RelDataTypeFactory typeFactory) { return rowType; } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateTable.java b/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateTable.java index 66a030b8b..9f027dd78 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateTable.java +++ b/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateTable.java @@ -11,7 +11,7 @@ public class CreateTable extends AbstractRelNode { private final List tableName; private final RelNode input; - public CreateTable(List tableName, RelNode input) { + public CreateTable(final List tableName, final RelNode input) { super(input.getCluster(), input.getTraitSet()); this.tableName = tableName; @@ -24,7 +24,7 @@ protected RelDataType deriveRowType() { } @Override - public RelWriter explainTerms(RelWriter pw) { + public RelWriter explainTerms(final RelWriter pw) { return super.explainTerms(pw).input("input", getInput()).item("tableName", getTableName()); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateView.java b/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateView.java index ef1e228cb..1cbd454cc 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateView.java +++ b/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateView.java @@ -10,7 +10,7 @@ public class CreateView extends AbstractRelNode { private final List viewName; private final RelNode input; - public CreateView(List viewName, RelNode input) { + public CreateView(final List viewName, final RelNode input) { super(input.getCluster(), input.getTraitSet()); this.viewName = viewName; this.input = input; @@ -22,7 +22,7 @@ protected RelDataType deriveRowType() { } @Override - public RelWriter explainTerms(RelWriter pw) { + public RelWriter explainTerms(final RelWriter pw) { return super.explainTerms(pw).input("input", getInput()).item("viewName", getViewName()); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/DdlSqlToRelConverter.java b/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/DdlSqlToRelConverter.java index 6a237b366..ec9987dfd 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/DdlSqlToRelConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/DdlSqlToRelConverter.java @@ -30,7 +30,7 @@ private Function findDdlHandler(final SqlCall call) { return null; } - public DdlSqlToRelConverter(SqlToRelConverter converter) { + public DdlSqlToRelConverter(final SqlToRelConverter converter) { this.converter = converter; ddlHandlers.put(SqlCreateTable.class, sqlCall -> handleCreateTable((SqlCreateTable) sqlCall)); @@ -38,8 +38,8 @@ public DdlSqlToRelConverter(SqlToRelConverter converter) { } @Override - public RelRoot visit(SqlCall sqlCall) { - Function ddlHandler = findDdlHandler(sqlCall); + public RelRoot visit(final SqlCall sqlCall) { + final Function ddlHandler = findDdlHandler(sqlCall); if (ddlHandler != null) { return ddlHandler.apply(sqlCall); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java index 8d81b0b00..8904948d6 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java @@ -36,33 +36,34 @@ protected ImmutableList getSigs() { } public AggregateFunctionConverter( - List functions, RelDataTypeFactory typeFactory) { + final List functions, + final RelDataTypeFactory typeFactory) { super(functions, typeFactory); } public AggregateFunctionConverter( - List functions, - List additionalSignatures, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter) { + final List functions, + final List additionalSignatures, + final RelDataTypeFactory typeFactory, + final TypeConverter typeConverter) { super(functions, additionalSignatures, typeFactory, typeConverter); } @Override protected AggregateFunctionInvocation generateBinding( - WrappedAggregateCall call, - SimpleExtension.AggregateFunctionVariant function, - List arguments, - Type outputType) { - AggregateCall agg = call.getUnderlying(); + final WrappedAggregateCall call, + final SimpleExtension.AggregateFunctionVariant function, + final List arguments, + final Type outputType) { + final AggregateCall agg = call.getUnderlying(); - List sorts = + final List sorts = agg.getCollation() != null ? agg.getCollation().getFieldCollations().stream() .map(r -> SubstraitRelVisitor.toSortField(r, call.inputType)) .collect(java.util.stream.Collectors.toList()) : Collections.emptyList(); - Expression.AggregationInvocation invocation = + final Expression.AggregationInvocation invocation = agg.isDistinct() ? Expression.AggregationInvocation.DISTINCT : Expression.AggregationInvocation.ALL; @@ -76,12 +77,12 @@ protected AggregateFunctionInvocation generateBinding( } public Optional convert( - RelNode input, - Type.Struct inputType, - AggregateCall call, - Function topLevelConverter) { + final RelNode input, + final Type.Struct inputType, + final AggregateCall call, + final Function topLevelConverter) { - FunctionFinder m = getFunctionFinder(call); + final FunctionFinder m = getFunctionFinder(call); if (m == null) { return Optional.empty(); } @@ -89,11 +90,12 @@ public Optional convert( return Optional.empty(); } - WrappedAggregateCall wrapped = new WrappedAggregateCall(call, input, rexBuilder, inputType); + final WrappedAggregateCall wrapped = + new WrappedAggregateCall(call, input, rexBuilder, inputType); return m.attemptMatch(wrapped, topLevelConverter); } - protected FunctionFinder getFunctionFinder(AggregateCall call) { + protected FunctionFinder getFunctionFinder(final AggregateCall call) { // replace COUNT() + distinct == true and approximate == true with APPROX_COUNT_DISTINCT // before converting into substrait function SqlAggFunction aggFunction = call.getAggregation(); @@ -101,7 +103,7 @@ protected FunctionFinder getFunctionFinder(AggregateCall call) { aggFunction = SqlStdOperatorTable.APPROX_COUNT_DISTINCT; } - SqlAggFunction lookupFunction = + final SqlAggFunction lookupFunction = // Replace default Calcite aggregate calls with Substrait specific variants. // See toSubstraitAggVariant for more details. AggregateFunctions.toSubstraitAggVariant(aggFunction).orElse(aggFunction); @@ -115,7 +117,10 @@ static class WrappedAggregateCall implements FunctionConverter.GenericCall { private final Type.Struct inputType; private WrappedAggregateCall( - AggregateCall call, RelNode input, RexBuilder rexBuilder, Type.Struct inputType) { + final AggregateCall call, + final RelNode input, + final RexBuilder rexBuilder, + final Type.Struct inputType) { this.call = call; this.input = input; this.rexBuilder = rexBuilder; diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 3406de7de..72691d91b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -22,7 +22,7 @@ public class CallConverters { public static Function CAST = typeConverter -> (call, visitor) -> { - Expression.FailureBehavior failureBehavior; + final Expression.FailureBehavior failureBehavior; switch (call.getKind()) { case CAST: failureBehavior = Expression.FailureBehavior.THROW_EXCEPTION; @@ -60,15 +60,15 @@ public class CallConverters { if (call.getKind() != SqlKind.REINTERPRET) { return null; } - Expression operand = visitor.apply(call.getOperands().get(0)); - Type type = typeConverter.toSubstrait(call.getType()); + final Expression operand = visitor.apply(call.getOperands().get(0)); + final Type type = typeConverter.toSubstrait(call.getType()); // For now, we only support handling of SqlKind.REINTEPRETET for the case of stored // user-defined literals if (operand instanceof Expression.FixedBinaryLiteral && type instanceof Type.UserDefined) { - Expression.FixedBinaryLiteral literal = (Expression.FixedBinaryLiteral) operand; - Type.UserDefined t = (Type.UserDefined) type; + final Expression.FixedBinaryLiteral literal = (Expression.FixedBinaryLiteral) operand; + final Type.UserDefined t = (Type.UserDefined) type; return Expression.UserDefinedLiteral.builder() .urn(t.urn()) @@ -100,12 +100,12 @@ public class CallConverters { // else) assert call.getOperands().size() % 2 == 1; - List caseArgs = + final List caseArgs = call.getOperands().stream().map(visitor).collect(java.util.stream.Collectors.toList()); - int last = caseArgs.size() - 1; + final int last = caseArgs.size() - 1; // for if/else, process in reverse to maintain query order - List caseConditions = new ArrayList<>(); + final List caseConditions = new ArrayList<>(); for (int i = 0; i < last; i += 2) { caseConditions.add( Expression.IfClause.builder() @@ -114,7 +114,7 @@ public class CallConverters { .build()); } - Expression defaultResult = caseArgs.get(last); + final Expression defaultResult = caseArgs.get(last); return ExpressionCreator.ifThenStatement(defaultResult, caseConditions); }; @@ -124,18 +124,18 @@ public class CallConverters { * RexProgram, RexNode)} */ public static Function CREATE_SEARCH_CONV = - (RexBuilder rexBuilder) -> - (RexCall call, Function visitor) -> { + (final RexBuilder rexBuilder) -> + (final RexCall call, final Function visitor) -> { if (call.getKind() != SqlKind.SEARCH) { return null; } else { - RexNode expandSearch = RexUtil.expandSearch(rexBuilder, null, call); + final RexNode expandSearch = RexUtil.expandSearch(rexBuilder, null, call); // if no expansion happened, avoid infinite recursion. return expandSearch.equals(call) ? null : visitor.apply(expandSearch); } }; - public static List defaults(TypeConverter typeConverter) { + public static List defaults(final TypeConverter typeConverter) { return ImmutableList.of( new FieldSelectionConverter(typeConverter), CallConverters.CASE, @@ -150,7 +150,7 @@ public interface SimpleCallConverter extends CallConverter { @Override default Optional convert( - RexCall call, Function topLevelConverter) { + final RexCall call, final Function topLevelConverter) { return Optional.ofNullable(apply(call, topLevelConverter)); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java index b69ef9b02..2fedd4b65 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java @@ -75,7 +75,7 @@ public class EnumConverter { } private static Optional> constructValue( - Class> cls, Supplier> option) { + final Class> cls, final Supplier> option) { if (cls.isAssignableFrom(TimeUnitRange.class)) { return option.get().map(TimeUnitRange::valueOf); } @@ -88,12 +88,15 @@ private static Optional> constructValue( } static Optional toRex( - RexBuilder rexBuilder, SimpleExtension.Function fnDef, int argIdx, EnumArg e) { - ArgAnchor aAnch = argAnchor(fnDef, argIdx); - Optional>> v = + final RexBuilder rexBuilder, + final SimpleExtension.Function fnDef, + final int argIdx, + final EnumArg e) { + final ArgAnchor aAnch = argAnchor(fnDef, argIdx); + final Optional>> v = Optional.ofNullable(calciteEnumMap.getOrDefault(aAnch, null)); - Supplier> sOptionVal = + final Supplier> sOptionVal = () -> { if (e.value().isPresent()) { return Optional.of(e.value().get()); @@ -106,17 +109,17 @@ static Optional toRex( } private static Optional findEnumArg( - SimpleExtension.Function function, ArgAnchor enumAnchor) { + final SimpleExtension.Function function, final ArgAnchor enumAnchor) { if (enumAnchor.fn == function.getAnchor()) { return Optional.empty(); } else { - List args = function.args(); + final List args = function.args(); if (args.size() <= enumAnchor.argIdx) { return Optional.empty(); } - Argument arg = args.get(enumAnchor.argIdx); + final Argument arg = args.get(enumAnchor.argIdx); if (arg instanceof SimpleExtension.EnumArgument) { return Optional.of((SimpleExtension.EnumArgument) arg); } else { @@ -126,14 +129,14 @@ private static Optional findEnumArg( } static Optional fromRex( - SimpleExtension.Function function, RexLiteral literal, int argIdx) { + final SimpleExtension.Function function, final RexLiteral literal, final int argIdx) { switch (literal.getType().getSqlTypeName()) { case SYMBOL: { - Object v = literal.getValue(); + final Object v = literal.getValue(); if (!literal.isNull() && (v instanceof Enum)) { - Enum value = (Enum) v; - ArgAnchor enumAnchor = argAnchor(function, argIdx); + final Enum value = (Enum) v; + final ArgAnchor enumAnchor = argAnchor(function, argIdx); return findEnumArg(function, enumAnchor).map(ea -> EnumArg.of(ea, value.name())); } @@ -144,11 +147,11 @@ static Optional fromRex( } } - static boolean canConvert(Enum value) { + static boolean canConvert(final Enum value) { return value != null && calciteEnumMap.containsValue(value.getClass()); } - static boolean isEnumValue(RexNode value) { + static boolean isEnumValue(final RexNode value) { return value instanceof RexLiteral && value.getType().getSqlTypeName() == SqlTypeName.SYMBOL; } @@ -167,23 +170,23 @@ public int hashCode() { } @Override - public boolean equals(Object obj) { + public boolean equals(final Object obj) { if (this == obj) { return true; } if (!(obj instanceof ArgAnchor)) { return false; } - ArgAnchor other = (ArgAnchor) obj; + final ArgAnchor other = (ArgAnchor) obj; return Objects.equals(fn, other.fn) && argIdx == other.argIdx; } } - private static ArgAnchor argAnchor(String fnNS, String fnSig, int argIdx) { + private static ArgAnchor argAnchor(final String fnNS, final String fnSig, final int argIdx) { return new ArgAnchor(SimpleExtension.FunctionAnchor.of(fnNS, fnSig), argIdx); } - private static ArgAnchor argAnchor(SimpleExtension.Function fnDef, int argIdx) { + private static ArgAnchor argAnchor(final SimpleExtension.Function fnDef, final int argIdx) { return new ArgAnchor( SimpleExtension.FunctionAnchor.of(fnDef.getAnchor().urn(), fnDef.getAnchor().key()), argIdx); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 2b8052889..aa8eaac10 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -88,10 +88,10 @@ public class ExpressionRexConverter protected SubstraitRelNodeConverter relNodeConverter; public ExpressionRexConverter( - RelDataTypeFactory typeFactory, - ScalarFunctionConverter scalarFunctionConverter, - WindowFunctionConverter windowFunctionConverter, - TypeConverter typeConverter) { + final RelDataTypeFactory typeFactory, + final ScalarFunctionConverter scalarFunctionConverter, + final WindowFunctionConverter windowFunctionConverter, + final TypeConverter typeConverter) { this.typeFactory = typeFactory; this.typeConverter = typeConverter; this.rexBuilder = new RexBuilder(typeFactory); @@ -104,79 +104,90 @@ public void setRelNodeConverter(final SubstraitRelNodeConverter substraitRelNode } @Override - public RexNode visit(Expression.NullLiteral expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.NullLiteral expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral(null, typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.UserDefinedLiteral expr, Context context) + public RexNode visit(final Expression.UserDefinedLiteral expr, final Context context) throws RuntimeException { - RexLiteral binaryLiteral = + final RexLiteral binaryLiteral = rexBuilder.makeBinaryLiteral(new ByteString(expr.value().toByteArray())); - RelDataType type = typeConverter.toCalcite(typeFactory, expr.getType()); + final RelDataType type = typeConverter.toCalcite(typeFactory, expr.getType()); return rexBuilder.makeReinterpretCast(type, binaryLiteral, rexBuilder.makeLiteral(false)); } @Override - public RexNode visit(Expression.BoolLiteral expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.BoolLiteral expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral(expr.value()); } @Override - public RexNode visit(Expression.I8Literal expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.I8Literal expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.I16Literal expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.I16Literal expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.I32Literal expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.I32Literal expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.I64Literal expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.I64Literal expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.FP32Literal expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.FP32Literal expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.FP64Literal expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.FP64Literal expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.FixedCharLiteral expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.FixedCharLiteral expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral(expr.value()); } @Override - public RexNode visit(Expression.StrLiteral expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.StrLiteral expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType()), true); } @Override - public RexNode visit(Expression.VarCharLiteral expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.VarCharLiteral expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType()), true); } @Override - public RexNode visit(Expression.FixedBinaryLiteral expr, Context context) + public RexNode visit(final Expression.FixedBinaryLiteral expr, final Context context) throws RuntimeException { return rexBuilder.makeLiteral( new ByteString(expr.value().toByteArray()), @@ -185,7 +196,8 @@ public RexNode visit(Expression.FixedBinaryLiteral expr, Context context) } @Override - public RexNode visit(Expression.BinaryLiteral expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.BinaryLiteral expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( new ByteString(expr.value().toByteArray()), typeConverter.toCalcite(typeFactory, expr.getType()), @@ -193,90 +205,96 @@ public RexNode visit(Expression.BinaryLiteral expr, Context context) throws Runt } @Override - public RexNode visit(Expression.TimeLiteral expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.TimeLiteral expr, final Context context) + throws RuntimeException { // Expression.TimeLiteral is Microseconds // Construct a TimeString : // 1. Truncate microseconds to seconds // 2. Get the fraction seconds in precision of nanoseconds. // 3. Construct TimeString : seconds + fraction_seconds part. - long microSec = expr.value(); - long seconds = TimeUnit.MICROSECONDS.toSeconds(microSec); - int fracSecondsInNano = + final long microSec = expr.value(); + final long seconds = TimeUnit.MICROSECONDS.toSeconds(microSec); + final int fracSecondsInNano = (int) (TimeUnit.MICROSECONDS.toNanos(microSec) - TimeUnit.SECONDS.toNanos(seconds)); - TimeString timeString = + final TimeString timeString = TimeString.fromMillisOfDay((int) TimeUnit.SECONDS.toMillis(seconds)) .withNanos(fracSecondsInNano); return rexBuilder.makeLiteral(timeString, typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(SingleOrList expr, Context context) throws RuntimeException { - RexNode lhs = expr.condition().accept(this, context); + public RexNode visit(final SingleOrList expr, final Context context) throws RuntimeException { + final RexNode lhs = expr.condition().accept(this, context); return rexBuilder.makeIn( lhs, expr.options().stream().map(e -> e.accept(this, context)).collect(Collectors.toList())); } @Override - public RexNode visit(Expression.DateLiteral expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.DateLiteral expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.TimestampLiteral expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.TimestampLiteral expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( getTimestampString(expr.value()), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(TimestampTZLiteral expr, Context context) throws RuntimeException { + public RexNode visit(final TimestampTZLiteral expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( getTimestampString(expr.value()), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(PrecisionTimestampLiteral expr, Context context) throws RuntimeException { + public RexNode visit(final PrecisionTimestampLiteral expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( getTimestampString(expr.value(), expr.precision()), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(PrecisionTimestampTZLiteral expr, Context context) throws RuntimeException { + public RexNode visit(final PrecisionTimestampTZLiteral expr, final Context context) + throws RuntimeException { return rexBuilder.makeLiteral( getTimestampString(expr.value(), expr.precision()), typeConverter.toCalcite(typeFactory, expr.getType())); } - private TimestampString getTimestampString(long microSec) { + private TimestampString getTimestampString(final long microSec) { return getTimestampString(microSec, 6); } - private TimestampString getTimestampString(long value, int precision) { + private TimestampString getTimestampString(final long value, final int precision) { switch (precision) { case 0: return TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(value)); case 3: { - long seconds = TimeUnit.MILLISECONDS.toSeconds(value); - int fracSecondsInNano = + final long seconds = TimeUnit.MILLISECONDS.toSeconds(value); + final int fracSecondsInNano = (int) (TimeUnit.MILLISECONDS.toNanos(value) - TimeUnit.SECONDS.toNanos(seconds)); return TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(seconds)) .withNanos(fracSecondsInNano); } case 6: { - long seconds = TimeUnit.MICROSECONDS.toSeconds(value); - int fracSecondsInNano = + final long seconds = TimeUnit.MICROSECONDS.toSeconds(value); + final int fracSecondsInNano = (int) (TimeUnit.MICROSECONDS.toNanos(value) - TimeUnit.SECONDS.toNanos(seconds)); return TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(seconds)) .withNanos(fracSecondsInNano); } case 9: { - long seconds = TimeUnit.NANOSECONDS.toSeconds(value); - int fracSecondsInNano = (int) (value - TimeUnit.SECONDS.toNanos(seconds)); + final long seconds = TimeUnit.NANOSECONDS.toSeconds(value); + final int fracSecondsInNano = (int) (value - TimeUnit.SECONDS.toNanos(seconds)); return TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(seconds)) .withNanos(fracSecondsInNano); } @@ -287,16 +305,16 @@ private TimestampString getTimestampString(long value, int precision) { } @Override - public RexNode visit(Expression.IntervalYearLiteral expr, Context context) + public RexNode visit(final Expression.IntervalYearLiteral expr, final Context context) throws RuntimeException { return rexBuilder.makeIntervalLiteral( new BigDecimal(expr.years() * 12 + expr.months()), YEAR_MONTH_INTERVAL); } @Override - public RexNode visit(Expression.IntervalDayLiteral expr, Context context) + public RexNode visit(final Expression.IntervalDayLiteral expr, final Context context) throws RuntimeException { - long milliseconds = + final long milliseconds = expr.precision() > 3 ? (expr.subseconds() / (int) Math.pow(10, expr.precision() - 3)) : (expr.subseconds() * (int) Math.pow(10, 3 - expr.precision())); @@ -307,29 +325,33 @@ public RexNode visit(Expression.IntervalDayLiteral expr, Context context) } @Override - public RexNode visit(Expression.DecimalLiteral expr, Context context) throws RuntimeException { - byte[] value = expr.value().toByteArray(); - BigDecimal decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale(), 16); + public RexNode visit(final Expression.DecimalLiteral expr, final Context context) + throws RuntimeException { + final byte[] value = expr.value().toByteArray(); + final BigDecimal decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale(), 16); return rexBuilder.makeLiteral(decimal, typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.ListLiteral expr, Context context) throws RuntimeException { - List args = + public RexNode visit(final Expression.ListLiteral expr, final Context context) + throws RuntimeException { + final List args = expr.values().stream().map(l -> l.accept(this, context)).collect(Collectors.toList()); return rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, args); } @Override - public RexNode visit(Expression.EmptyListLiteral expr, Context context) throws RuntimeException { - RelDataType calciteType = typeConverter.toCalcite(typeFactory, expr.getType()); + public RexNode visit(final Expression.EmptyListLiteral expr, final Context context) + throws RuntimeException { + final RelDataType calciteType = typeConverter.toCalcite(typeFactory, expr.getType()); return rexBuilder.makeCall( calciteType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, Collections.emptyList()); } @Override - public RexNode visit(Expression.MapLiteral expr, Context context) throws RuntimeException { - List args = + public RexNode visit(final Expression.MapLiteral expr, final Context context) + throws RuntimeException { + final List args = expr.values().entrySet().stream() .flatMap( entry -> @@ -341,25 +363,26 @@ public RexNode visit(Expression.MapLiteral expr, Context context) throws Runtime } @Override - public RexNode visit(Expression.IfThen expr, Context context) throws RuntimeException { + public RexNode visit(final Expression.IfThen expr, final Context context) + throws RuntimeException { // In Calcite, the arguments to the CASE operator are given as: // ... ... - Stream ifThenArgs = + final Stream ifThenArgs = expr.ifClauses().stream() .flatMap( clause -> Stream.of( clause.condition().accept(this, context), clause.then().accept(this, context))); - Stream elseArg = Stream.of(expr.elseClause().accept(this, context)); - List args = Stream.concat(ifThenArgs, elseArg).collect(Collectors.toList()); + final Stream elseArg = Stream.of(expr.elseClause().accept(this, context)); + final List args = Stream.concat(ifThenArgs, elseArg).collect(Collectors.toList()); return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args); } @Override - public RexNode visit(Switch expr, Context context) throws RuntimeException { - RexNode match = expr.match().accept(this, context); - Stream caseThenArgs = + public RexNode visit(final Switch expr, final Context context) throws RuntimeException { + final RexNode match = expr.match().accept(this, context); + final Stream caseThenArgs = expr.switchClauses().stream() .flatMap( clause -> @@ -369,15 +392,15 @@ public RexNode visit(Switch expr, Context context) throws RuntimeException { match, clause.condition().accept(this, context)), clause.then().accept(this, context))); - Stream defaultArg = Stream.of(expr.defaultClause().accept(this, context)); - List args = Stream.concat(caseThenArgs, defaultArg).collect(Collectors.toList()); + final Stream defaultArg = Stream.of(expr.defaultClause().accept(this, context)); + final List args = Stream.concat(caseThenArgs, defaultArg).collect(Collectors.toList()); return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args); } @Override - public RexNode visit(Expression.ScalarFunctionInvocation expr, Context context) + public RexNode visit(final Expression.ScalarFunctionInvocation expr, final Context context) throws RuntimeException { - SqlOperator operator = + final SqlOperator operator = scalarFunctionConverter .getSqlOperatorFromSubstraitFunc(expr.declaration().key(), expr.outputType()) .orElseThrow( @@ -386,27 +409,27 @@ public RexNode visit(Expression.ScalarFunctionInvocation expr, Context context) callConversionFailureMessage( "scalar", expr.declaration().name(), expr.arguments()))); - List eArgs = scalarFunctionConverter.getExpressionArguments(expr); - List args = + final List eArgs = scalarFunctionConverter.getExpressionArguments(expr); + final List args = IntStream.range(0, eArgs.size()) .mapToObj(i -> eArgs.get(i).accept(expr.declaration(), i, this, context)) .collect(Collectors.toList()); - RelDataType returnType = typeConverter.toCalcite(typeFactory, expr.outputType()); + final RelDataType returnType = typeConverter.toCalcite(typeFactory, expr.outputType()); return rexBuilder.makeCall(returnType, operator, args); } private String callConversionFailureMessage( - String functionType, String name, List args) { + final String functionType, final String name, final List args) { return String.format( "Unable to convert %s function %s(%s).", functionType, name, args.stream().map(this::convert).collect(Collectors.joining(", "))); } @Override - public RexNode visit(Expression.WindowFunctionInvocation expr, Context context) + public RexNode visit(final Expression.WindowFunctionInvocation expr, final Context context) throws RuntimeException { - SqlOperator operator = + final SqlOperator operator = windowFunctionConverter .getSqlOperatorFromSubstraitFunc(expr.declaration().key(), expr.outputType()) .orElseThrow( @@ -415,40 +438,40 @@ public RexNode visit(Expression.WindowFunctionInvocation expr, Context context) callConversionFailureMessage( "window", expr.declaration().name(), expr.arguments()))); - RelDataType outputType = typeConverter.toCalcite(typeFactory, expr.outputType()); + final RelDataType outputType = typeConverter.toCalcite(typeFactory, expr.outputType()); - List eArgs = expr.arguments(); - List args = + final List eArgs = expr.arguments(); + final List args = IntStream.range(0, eArgs.size()) .mapToObj(i -> eArgs.get(i).accept(expr.declaration(), i, this, context)) .collect(Collectors.toList()); - List partitionKeys = + final List partitionKeys = expr.partitionBy().stream().map(e -> e.accept(this, context)).collect(Collectors.toList()); - ImmutableList orderKeys = + final ImmutableList orderKeys = expr.sort().stream() .map( sf -> { - Set direction = asSqlKind(sf.direction()); + final Set direction = asSqlKind(sf.direction()); return new RexFieldCollation(sf.expr().accept(this, context), direction); }) .collect(ImmutableList.toImmutableList()); - RexWindowBound lowerBound = ToRexWindowBound.lowerBound(rexBuilder, expr.lowerBound()); - RexWindowBound upperBound = ToRexWindowBound.upperBound(rexBuilder, expr.upperBound()); + final RexWindowBound lowerBound = ToRexWindowBound.lowerBound(rexBuilder, expr.lowerBound()); + final RexWindowBound upperBound = ToRexWindowBound.upperBound(rexBuilder, expr.upperBound()); - boolean rowMode = isRowMode(expr); - boolean distinct = isDistinct(expr); + final boolean rowMode = isRowMode(expr); + final boolean distinct = isDistinct(expr); // For queries like: SELECT last_value() IGNORE NULLS OVER ... // Substrait has no mechanism to set this, so by default it is false - boolean ignoreNulls = false; + final boolean ignoreNulls = false; // These both control a rewrite rule within rexBuilder.makeOver that rewrites the given // expression into a case expression. These values are set as such to avoid this rewrite. - boolean nullWhenCountZero = false; - boolean allowPartial = true; + final boolean nullWhenCountZero = false; + final boolean allowPartial = true; return rexBuilder.makeOver( outputType, @@ -465,7 +488,7 @@ public RexNode visit(Expression.WindowFunctionInvocation expr, Context context) ignoreNulls); } - private Set asSqlKind(Expression.SortDirection direction) { + private Set asSqlKind(final Expression.SortDirection direction) { switch (direction) { case ASC_NULLS_FIRST: return Set.of(SqlKind.NULLS_FIRST); @@ -482,8 +505,8 @@ private Set asSqlKind(Expression.SortDirection direction) { } } - private boolean isRowMode(Expression.WindowFunctionInvocation expr) { - Expression.WindowBoundsType boundsType = expr.boundsType(); + private boolean isRowMode(final Expression.WindowFunctionInvocation expr) { + final Expression.WindowBoundsType boundsType = expr.boundsType(); switch (boundsType) { case ROWS: @@ -498,8 +521,8 @@ private boolean isRowMode(Expression.WindowFunctionInvocation expr) { } } - private boolean isDistinct(Expression.WindowFunctionInvocation expr) { - Expression.AggregationInvocation invocation = expr.invocation(); + private boolean isDistinct(final Expression.WindowFunctionInvocation expr) { + final Expression.AggregationInvocation invocation = expr.invocation(); switch (invocation) { case UNSPECIFIED: @@ -513,11 +536,12 @@ private boolean isDistinct(Expression.WindowFunctionInvocation expr) { } @Override - public RexNode visit(Expression.InPredicate expr, Context context) throws RuntimeException { - List needles = + public RexNode visit(final Expression.InPredicate expr, final Context context) + throws RuntimeException { + final List needles = expr.needles().stream().map(e -> e.accept(this, context)).collect(Collectors.toList()); context.incrementSubqueryDepth(); - RelNode rel = expr.haystack().accept(relNodeConverter, context); + final RelNode rel = expr.haystack().accept(relNodeConverter, context); context.decrementSubqueryDepth(); return RexSubQuery.in(rel, ImmutableList.copyOf(needles)); } @@ -528,48 +552,48 @@ static class ToRexWindowBound private final RexBuilder rexBuilder; private final RexWindowBound unboundedVariant; - static RexWindowBound lowerBound(RexBuilder rexBuilder, WindowBound bound) { + static RexWindowBound lowerBound(final RexBuilder rexBuilder, final WindowBound bound) { // per the spec, unbounded on the lower bound means the start of the partition // thus UNBOUNDED_PRECEDING should be used when bound is unbounded return bound.accept(new ToRexWindowBound(rexBuilder, RexWindowBounds.UNBOUNDED_PRECEDING)); } - static RexWindowBound upperBound(RexBuilder rexBuilder, WindowBound bound) { + static RexWindowBound upperBound(final RexBuilder rexBuilder, final WindowBound bound) { // per the spec, unbounded on the upper bound means the end of the partition // thus UNBOUNDED_FOLLOWING should be used when bound is unbounded return bound.accept(new ToRexWindowBound(rexBuilder, RexWindowBounds.UNBOUNDED_FOLLOWING)); } - private ToRexWindowBound(RexBuilder rexBuilder, RexWindowBound unboundedVariant) { + private ToRexWindowBound(final RexBuilder rexBuilder, final RexWindowBound unboundedVariant) { this.rexBuilder = rexBuilder; this.unboundedVariant = unboundedVariant; } @Override - public RexWindowBound visit(WindowBound.Preceding preceding) { - BigDecimal offset = BigDecimal.valueOf(preceding.offset()); + public RexWindowBound visit(final WindowBound.Preceding preceding) { + final BigDecimal offset = BigDecimal.valueOf(preceding.offset()); return RexWindowBounds.preceding(rexBuilder.makeBigintLiteral(offset)); } @Override - public RexWindowBound visit(WindowBound.Following following) { - BigDecimal offset = BigDecimal.valueOf(following.offset()); + public RexWindowBound visit(final WindowBound.Following following) { + final BigDecimal offset = BigDecimal.valueOf(following.offset()); return RexWindowBounds.following(rexBuilder.makeBigintLiteral(offset)); } @Override - public RexWindowBound visit(WindowBound.CurrentRow currentRow) { + public RexWindowBound visit(final WindowBound.CurrentRow currentRow) { return RexWindowBounds.CURRENT_ROW; } @Override - public RexWindowBound visit(WindowBound.Unbounded unbounded) { + public RexWindowBound visit(final WindowBound.Unbounded unbounded) { return unboundedVariant; } } - private String convert(FunctionArg a) { - String v; + private String convert(final FunctionArg a) { + final String v; if (a instanceof EnumArg) { v = ((EnumArg) a).value().toString(); } else if (a instanceof Expression) { @@ -583,8 +607,8 @@ private String convert(FunctionArg a) { } @Override - public RexNode visit(Expression.Cast expr, Context context) throws RuntimeException { - boolean safeCast = expr.failureBehavior() == FailureBehavior.RETURN_NULL; + public RexNode visit(final Expression.Cast expr, final Context context) throws RuntimeException { + final boolean safeCast = expr.failureBehavior() == FailureBehavior.RETURN_NULL; return rexBuilder.makeAbstractCast( typeConverter.toCalcite(typeFactory, expr.getType()), expr.input().accept(this, context), @@ -592,7 +616,7 @@ public RexNode visit(Expression.Cast expr, Context context) throws RuntimeExcept } @Override - public RexNode visit(FieldReference expr, Context context) throws RuntimeException { + public RexNode visit(final FieldReference expr, final Context context) throws RuntimeException { if (expr.isSimpleRootReference()) { final ReferenceSegment segment = expr.segments().get(0); @@ -636,7 +660,7 @@ public RexNode visit(FieldReference expr, Context context) throws RuntimeExcepti } @Override - public RexNode visitFallback(Expression expr, Context context) { + public RexNode visitFallback(final Expression expr, final Context context) { throw new UnsupportedOperationException( String.format( "Expression %s of type %s not handled by visitor type %s.", @@ -645,13 +669,17 @@ public RexNode visitFallback(Expression expr, Context context) { @Override public RexNode visitExpr( - SimpleExtension.Function fnDef, int argIdx, Expression e, Context context) + final SimpleExtension.Function fnDef, + final int argIdx, + final Expression e, + final Context context) throws RuntimeException { return e.accept(this, context); } @Override - public RexNode visitType(SimpleExtension.Function fnDef, int argIdx, Type t, Context context) + public RexNode visitType( + final SimpleExtension.Function fnDef, final int argIdx, final Type t, final Context context) throws RuntimeException { throw new UnsupportedOperationException( String.format( @@ -661,7 +689,10 @@ public RexNode visitType(SimpleExtension.Function fnDef, int argIdx, Type t, Con @Override public RexNode visitEnumArg( - SimpleExtension.Function fnDef, int argIdx, EnumArg e, Context context) + final SimpleExtension.Function fnDef, + final int argIdx, + final EnumArg e, + final Context context) throws RuntimeException { return EnumConverter.toRex(rexBuilder, fnDef, argIdx, e) @@ -674,17 +705,17 @@ public RexNode visitEnumArg( } @Override - public RexNode visit(ScalarSubquery expr, Context context) throws RuntimeException { + public RexNode visit(final ScalarSubquery expr, final Context context) throws RuntimeException { context.incrementSubqueryDepth(); - RelNode inputRelnode = expr.input().accept(relNodeConverter, context); + final RelNode inputRelnode = expr.input().accept(relNodeConverter, context); context.decrementSubqueryDepth(); return RexSubQuery.scalar(inputRelnode); } @Override - public RexNode visit(SetPredicate expr, Context context) throws RuntimeException { + public RexNode visit(final SetPredicate expr, final Context context) throws RuntimeException { context.incrementSubqueryDepth(); - RelNode inputRelnode = expr.tuples().accept(relNodeConverter, context); + final RelNode inputRelnode = expr.tuples().accept(relNodeConverter, context); context.decrementSubqueryDepth(); switch (expr.predicateOp()) { case PREDICATE_OP_EXISTS: diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java index 14d3b8dfa..663c025a6 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FieldSelectionConverter.java @@ -21,20 +21,20 @@ public class FieldSelectionConverter implements CallConverter { private final TypeConverter typeConverter; - public FieldSelectionConverter(TypeConverter typeConverter) { + public FieldSelectionConverter(final TypeConverter typeConverter) { super(); this.typeConverter = typeConverter; } @Override public Optional convert( - RexCall call, Function topLevelConverter) { + final RexCall call, final Function topLevelConverter) { if (!(call.getKind() == SqlKind.ITEM)) { return Optional.empty(); } - RexNode toDereference = call.getOperands().get(0); - RexNode reference = call.getOperands().get(1); + final RexNode toDereference = call.getOperands().get(0); + final RexNode reference = call.getOperands().get(1); if (reference.getKind() != SqlKind.LITERAL || !(reference instanceof RexLiteral)) { LOGGER @@ -46,14 +46,14 @@ public Optional convert( return Optional.empty(); } - Literal literal = (new LiteralConverter(typeConverter)).convert((RexLiteral) reference); + final Literal literal = (new LiteralConverter(typeConverter)).convert((RexLiteral) reference); - Expression input = topLevelConverter.apply(toDereference); + final Expression input = topLevelConverter.apply(toDereference); switch (toDereference.getType().getSqlTypeName()) { case ROW: { - Optional index = toInt(literal); + final Optional index = toInt(literal); if (index.isEmpty()) { return Optional.empty(); } @@ -65,7 +65,7 @@ public Optional convert( } case ARRAY: { - Optional index = toInt(literal); + final Optional index = toInt(literal); if (index.isEmpty()) { return Optional.empty(); } @@ -79,12 +79,12 @@ public Optional convert( case MAP: { - Optional mapKey = toString(literal); + final Optional mapKey = toString(literal); if (mapKey.isEmpty()) { return Optional.empty(); } - Expression.Literal keyLiteral = ExpressionCreator.string(false, mapKey.get()); + final Expression.Literal keyLiteral = ExpressionCreator.string(false, mapKey.get()); if (input instanceof FieldReference) { return Optional.of(((FieldReference) input).dereferenceMap(keyLiteral)); } else { @@ -96,7 +96,7 @@ public Optional convert( return Optional.empty(); } - private Optional toInt(Expression.Literal l) { + private Optional toInt(final Expression.Literal l) { if (l instanceof Expression.I8Literal) { return Optional.of(((Expression.I8Literal) l).value()); } else if (l instanceof Expression.I16Literal) { @@ -110,7 +110,7 @@ private Optional toInt(Expression.Literal l) { return Optional.empty(); } - public Optional toString(Expression.Literal l) { + public Optional toString(final Expression.Literal l) { if (!(l instanceof Expression.FixedCharLiteral)) { LOGGER.atWarn().log("Literal expected to be char type but was not. {}", l); return Optional.empty(); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index d738ef157..bd28d1c15 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -89,7 +89,7 @@ public abstract class FunctionConverter< * @param functions the list of function variants to register * @param typeFactory the Calcite type factory */ - public FunctionConverter(List functions, RelDataTypeFactory typeFactory) { + public FunctionConverter(final List functions, final RelDataTypeFactory typeFactory) { this(functions, Collections.EMPTY_LIST, typeFactory, TypeConverter.DEFAULT); } @@ -106,49 +106,49 @@ public FunctionConverter(List functions, RelDataTypeFactory typeFactory) { * @param typeConverter the type converter to use */ public FunctionConverter( - List functions, - List additionalSignatures, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter) { + final List functions, + final List additionalSignatures, + final RelDataTypeFactory typeFactory, + final TypeConverter typeConverter) { this.rexBuilder = new RexBuilder(typeFactory); this.typeConverter = typeConverter; - List signatures = + final List signatures = new ArrayList<>(getSigs().size() + additionalSignatures.size()); signatures.addAll(additionalSignatures); signatures.addAll(getSigs()); this.typeFactory = typeFactory; this.substraitFuncKeyToSqlOperatorMap = ArrayListMultimap.create(); - ArrayListMultimap nameToFn = ArrayListMultimap.create(); - for (F f : functions) { + final ArrayListMultimap nameToFn = ArrayListMultimap.create(); + for (final F f : functions) { nameToFn.put(f.name().toLowerCase(Locale.ROOT), f); } - Multimap calciteOperators = + final Multimap calciteOperators = signatures.stream() .collect( Multimaps.toMultimap( FunctionMappings.Sig::name, Function.identity(), ArrayListMultimap::create)); - IdentityHashMap matcherMap = + final IdentityHashMap matcherMap = new IdentityHashMap(); - for (String key : nameToFn.keySet()) { - Collection sigs = calciteOperators.get(key); + for (final String key : nameToFn.keySet()) { + final Collection sigs = calciteOperators.get(key); if (sigs.isEmpty()) { LOGGER.atDebug().log("No binding for function: {}", key); } - for (Sig sig : sigs) { - List implList = nameToFn.get(key); + for (final Sig sig : sigs) { + final List implList = nameToFn.get(key); if (!implList.isEmpty()) { matcherMap.put(sig.operator(), new FunctionFinder(key, sig.operator(), implList)); } } } - for (Entry entry : nameToFn.entries()) { - String key = entry.getKey(); - F func = entry.getValue(); - for (FunctionMappings.Sig sig : calciteOperators.get(key)) { + for (final Entry entry : nameToFn.entries()) { + final String key = entry.getKey(); + final F func = entry.getValue(); + for (final FunctionMappings.Sig sig : calciteOperators.get(key)) { substraitFuncKeyToSqlOperatorMap.put(func.key(), sig.operator()); } } @@ -167,9 +167,10 @@ public FunctionConverter( * @param outputType the expected output type * @return the matching {@link SqlOperator}, or empty if no match found */ - public Optional getSqlOperatorFromSubstraitFunc(String key, Type outputType) { - Map resolver = getTypeBasedResolver(); - Collection operators = substraitFuncKeyToSqlOperatorMap.get(key); + public Optional getSqlOperatorFromSubstraitFunc( + final String key, final Type outputType) { + final Map resolver = getTypeBasedResolver(); + final Collection operators = substraitFuncKeyToSqlOperatorMap.get(key); if (operators.isEmpty()) { return Optional.empty(); } @@ -180,8 +181,8 @@ public Optional getSqlOperatorFromSubstraitFunc(String key, Type ou } // at least 2 operators. Use output type to resolve SqlOperator. - String outputTypeStr = outputType.accept(ToTypeString.INSTANCE); - List resolvedOperators = + final String outputTypeStr = outputType.accept(ToTypeString.INSTANCE); + final List resolvedOperators = operators.stream() .filter( operator -> @@ -214,7 +215,8 @@ protected class FunctionFinder { private final Optional> singularInputType; private final Util.IntRange argRange; - public FunctionFinder(String substraitName, SqlOperator operator, List functions) { + public FunctionFinder( + final String substraitName, final SqlOperator operator, final List functions) { this.substraitName = substraitName; this.operator = operator; this.functions = functions; @@ -223,9 +225,9 @@ public FunctionFinder(String substraitName, SqlOperator operator, List functi functions.stream().mapToInt(t -> t.getRange().getStartInclusive()).min().getAsInt(), functions.stream().mapToInt(t -> t.getRange().getEndExclusive()).max().getAsInt()); this.singularInputType = getSingularInputType(functions); - ImmutableListMultimap.Builder directMap = ImmutableListMultimap.builder(); - for (F func : functions) { - String key = func.key(); + final ImmutableListMultimap.Builder directMap = ImmutableListMultimap.builder(); + for (final F func : functions) { + final String key = func.key(); directMap.put(key, func); if (func.requiredArguments().size() != func.args().size()) { directMap.put(F.constructKey(substraitName, func.requiredArguments()), func); @@ -234,13 +236,13 @@ public FunctionFinder(String substraitName, SqlOperator operator, List functi this.directMap = directMap.build(); } - public boolean allowedArgCount(int count) { + public boolean allowedArgCount(final int count) { return argRange.within(count); } - private Optional signatureMatch(List inputTypes, Type outputType) { - for (F function : functions) { - List args = function.requiredArguments(); + private Optional signatureMatch(final List inputTypes, final Type outputType) { + for (final F function : functions) { + final List args = function.requiredArguments(); // Make sure that arguments & return are within bounds and match the types if (function.returnType() instanceof ParameterizedType && isMatch(outputType, (ParameterizedType) function.returnType()) @@ -266,12 +268,12 @@ && inputTypesMatchDefinedArguments(inputTypes, args)) { * @return true if the {@code inputTypes} satisfy the {@code args}, false otherwise */ private boolean inputTypesMatchDefinedArguments( - List inputTypes, List args) { + final List inputTypes, final List args) { - Map> wildcardToType = new HashMap<>(); + final Map> wildcardToType = new HashMap<>(); for (int i = 0; i < inputTypes.size(); i++) { - Type givenType = inputTypes.get(i); - SimpleExtension.ValueArgument wantType = + final Type givenType = inputTypes.get(i); + final SimpleExtension.ValueArgument wantType = (SimpleExtension.ValueArgument) args.get( // Variadic arguments should match the last argument's type @@ -304,20 +306,20 @@ private boolean inputTypesMatchDefinedArguments( *

    If this exists, the function finder will attempt to find a least-restrictive match using * these. */ - private Optional> getSingularInputType(List functions) { - List> matchers = new ArrayList<>(); - for (F f : functions) { + private Optional> getSingularInputType(final List functions) { + final List> matchers = new ArrayList<>(); + for (final F f : functions) { ParameterizedType firstType = null; // determine if all the required arguments are the of the same type. If so, - for (Argument a : f.requiredArguments()) { + for (final Argument a : f.requiredArguments()) { if (!(a instanceof SimpleExtension.ValueArgument)) { firstType = null; break; } - ParameterizedType pt = ((SimpleExtension.ValueArgument) a).value(); + final ParameterizedType pt = ((SimpleExtension.ValueArgument) a).value(); if (firstType == null) { firstType = pt; @@ -345,9 +347,9 @@ private Optional> getSingularInputType(List functi } } - private SingularArgumentMatcher singular(F function, ParameterizedType type) { + private SingularArgumentMatcher singular(final F function, final ParameterizedType type) { return (inputType, outputType) -> { - boolean check = isMatch(inputType, type); + final boolean check = isMatch(inputType, type); if (check) { return Optional.of(function); } @@ -355,10 +357,10 @@ private SingularArgumentMatcher singular(F function, ParameterizedType type) }; } - private SingularArgumentMatcher chained(List> matchers) { + private SingularArgumentMatcher chained(final List> matchers) { return (inputType, outputType) -> { - for (SingularArgumentMatcher s : matchers) { - Optional outcome = s.tryMatch(inputType, outputType); + for (final SingularArgumentMatcher s : matchers) { + final Optional outcome = s.tryMatch(inputType, outputType); if (outcome.isPresent()) { return outcome; } @@ -372,14 +374,14 @@ private SingularArgumentMatcher chained(List> matc * In case of a `RexLiteral` of an Enum value try both `req` and `op` signatures * for that argument position. */ - private Stream matchKeys(List rexOperands, List opTypes) { + private Stream matchKeys(final List rexOperands, final List opTypes) { assert (rexOperands.size() == opTypes.size()); if (rexOperands.isEmpty()) { return Stream.of(""); } else { - List> argTypeLists = + final List> argTypeLists = Streams.zip( rexOperands.stream(), opTypes.stream(), @@ -410,7 +412,8 @@ private Stream matchKeys(List rexOperands, List opTypes * @param topLevelConverter function to convert RexNode operands to Substrait Expressions * @return the matched Substrait function binding, or empty if no match found */ - public Optional attemptMatch(C call, Function topLevelConverter) { + public Optional attemptMatch( + final C call, final Function topLevelConverter) { /* * Here the RexLiteral with an Enum value is mapped to String Literal. @@ -421,39 +424,40 @@ public Optional attemptMatch(C call, Function topLevelCo * Note that if there are multiple registered function extensions which can match a particular Call, * the last one added to the extension collection will be matched. */ - List operandsList = call.getOperands().collect(Collectors.toList()); - List operands = + final List operandsList = call.getOperands().collect(Collectors.toList()); + final List operands = call.getOperands().map(topLevelConverter).collect(Collectors.toList()); - List opTypes = operands.stream().map(Expression::getType).collect(Collectors.toList()); + final List opTypes = + operands.stream().map(Expression::getType).collect(Collectors.toList()); - Type outputType = typeConverter.toSubstrait(call.getType()); + final Type outputType = typeConverter.toSubstrait(call.getType()); // try to do a direct match - List typeStrings = + final List typeStrings = opTypes.stream().map(t -> t.accept(ToTypeString.INSTANCE)).collect(Collectors.toList()); - Stream possibleKeys = matchKeys(operandsList, typeStrings); + final Stream possibleKeys = matchKeys(operandsList, typeStrings); - Optional directMatchKey = + final Optional directMatchKey = possibleKeys .map(argList -> substraitName + ":" + argList) .filter(directMap::containsKey) .findFirst(); if (directMatchKey.isPresent()) { - List variants = directMap.get(directMatchKey.get()); + final List variants = directMap.get(directMatchKey.get()); if (variants.isEmpty()) { return Optional.empty(); } - F variant = variants.get(variants.size() - 1); + final F variant = variants.get(variants.size() - 1); variant.validateOutputType(operands, outputType); - List funcArgs = + final List funcArgs = IntStream.range(0, operandsList.size()) .mapToObj( i -> { - RexNode r = operandsList.get(i); - Expression o = operands.get(i); + final RexNode r = operandsList.get(i); + final Expression o = operands.get(i); if (EnumConverter.isEnumValue(r)) { return EnumConverter.fromRex(variant, (RexLiteral) r, i).orElse(null); } else { @@ -461,7 +465,8 @@ public Optional attemptMatch(C call, Function topLevelCo } }) .collect(Collectors.toList()); - boolean allArgsMapped = funcArgs.stream().filter(Objects::isNull).findFirst().isEmpty(); + final boolean allArgsMapped = + funcArgs.stream().filter(Objects::isNull).findFirst().isEmpty(); if (allArgsMapped) { return Optional.of(generateBinding(call, variant, funcArgs, outputType)); } else { @@ -470,11 +475,11 @@ public Optional attemptMatch(C call, Function topLevelCo } if (singularInputType.isPresent()) { - Optional coerced = matchCoerced(call, outputType, operands); + final Optional coerced = matchCoerced(call, outputType, operands); if (coerced.isPresent()) { return coerced; } - Optional leastRestrictive = matchByLeastRestrictive(call, outputType, operands); + final Optional leastRestrictive = matchByLeastRestrictive(call, outputType, operands); if (leastRestrictive.isPresent()) { return leastRestrictive; } @@ -483,39 +488,40 @@ public Optional attemptMatch(C call, Function topLevelCo } private Optional matchByLeastRestrictive( - C call, Type outputType, List operands) { - RelDataType leastRestrictive = + final C call, final Type outputType, final List operands) { + final RelDataType leastRestrictive = typeFactory.leastRestrictive( call.getOperands().map(RexNode::getType).collect(Collectors.toList())); if (leastRestrictive == null) { return Optional.empty(); } - Type type = typeConverter.toSubstrait(leastRestrictive); - Optional out = singularInputType.orElseThrow().tryMatch(type, outputType); + final Type type = typeConverter.toSubstrait(leastRestrictive); + final Optional out = singularInputType.orElseThrow().tryMatch(type, outputType); return out.map( declaration -> { - List coercedArgs = coerceArguments(operands, type); + final List coercedArgs = coerceArguments(operands, type); declaration.validateOutputType(coercedArgs, outputType); return generateBinding(call, out.get(), coercedArgs, outputType); }); } - private Optional matchCoerced(C call, Type outputType, List expressions) { + private Optional matchCoerced( + final C call, final Type outputType, final List expressions) { // Convert the operands to the proper Substrait type - List operandTypes = + final List operandTypes = call.getOperands() .map(RexNode::getType) .map(typeConverter::toSubstrait) .collect(Collectors.toList()); // See if all the input types can be made to match the function - Optional matchFunction = signatureMatch(operandTypes, outputType); + final Optional matchFunction = signatureMatch(operandTypes, outputType); if (matchFunction.isEmpty()) { return Optional.empty(); } - List coercedArgs = + final List coercedArgs = Streams.zip( expressions.stream(), operandTypes.stream(), FunctionConverter::coerceArgument) .collect(Collectors.toList()); @@ -541,11 +547,12 @@ public interface GenericCall { * Coerced types according to an expected output type. Coercion is only done for type mismatches, * not for nullability or parameter mismatches. */ - private static List coerceArguments(List arguments, Type targetType) { + private static List coerceArguments( + final List arguments, final Type targetType) { return arguments.stream().map(a -> coerceArgument(a, targetType)).collect(Collectors.toList()); } - private static Expression coerceArgument(Expression argument, Type type) { + private static Expression coerceArgument(final Expression argument, final Type type) { if (isMatch(type, argument.getType())) { return argument; } @@ -561,7 +568,8 @@ private interface SingularArgumentMatcher { Optional tryMatch(Type type, Type outputType); } - private static boolean isMatch(ParameterizedType actualType, ParameterizedType targetType) { + private static boolean isMatch( + final ParameterizedType actualType, final ParameterizedType targetType) { if (targetType.isWildcard()) { return true; } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 8bb41ff39..2bc168f76 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -146,15 +146,15 @@ public class FunctionMappings { SqlStdOperatorTable.BIT_LEFT_SHIFT, resolver(SqlStdOperatorTable.BIT_LEFT_SHIFT, Set.of("i8", "i16", "i32", "i64"))); - public static void main(String[] args) { + public static void main(final String[] args) { SCALAR_SIGS.forEach(System.out::println); } - public static Sig s(SqlOperator operator, String substraitName) { + public static Sig s(final SqlOperator operator, final String substraitName) { return new Sig(operator, substraitName.toLowerCase(Locale.ROOT)); } - public static Sig s(SqlOperator operator) { + public static Sig s(final SqlOperator operator) { return s(operator, operator.getName().toLowerCase(Locale.ROOT)); } @@ -176,7 +176,7 @@ public SqlOperator operator() { } } - public static TypeBasedResolver resolver(SqlOperator operator, Set outTypes) { + public static TypeBasedResolver resolver(final SqlOperator operator, final Set outTypes) { return new TypeBasedResolver(operator, outTypes); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java index f8b4be1dd..a1505f9d3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java @@ -9,229 +9,230 @@ public class IgnoreNullableAndParameters private final ParameterizedType typeToMatch; - public IgnoreNullableAndParameters(ParameterizedType typeToMatch) { + public IgnoreNullableAndParameters(final ParameterizedType typeToMatch) { this.typeToMatch = typeToMatch; } @Override - public Boolean visit(Type.Bool type) { + public Boolean visit(final Type.Bool type) { return typeToMatch instanceof Type.Bool; } @Override - public Boolean visit(Type.I8 type) { + public Boolean visit(final Type.I8 type) { return typeToMatch instanceof Type.I8; } @Override - public Boolean visit(Type.I16 type) { + public Boolean visit(final Type.I16 type) { return typeToMatch instanceof Type.I16; } @Override - public Boolean visit(Type.I32 type) { + public Boolean visit(final Type.I32 type) { return typeToMatch instanceof Type.I32; } @Override - public Boolean visit(Type.I64 type) { + public Boolean visit(final Type.I64 type) { return typeToMatch instanceof Type.I64; } @Override - public Boolean visit(Type.FP32 type) { + public Boolean visit(final Type.FP32 type) { return typeToMatch instanceof Type.FP32; } @Override - public Boolean visit(Type.FP64 type) { + public Boolean visit(final Type.FP64 type) { return typeToMatch instanceof Type.FP64; } @Override - public Boolean visit(Type.Str type) { + public Boolean visit(final Type.Str type) { return typeToMatch instanceof Type.Str; } @Override - public Boolean visit(Type.Binary type) { + public Boolean visit(final Type.Binary type) { return typeToMatch instanceof Type.Binary; } @Override - public Boolean visit(Type.Date type) { + public Boolean visit(final Type.Date type) { return typeToMatch instanceof Type.Date; } @Override - public Boolean visit(Type.Time type) { + public Boolean visit(final Type.Time type) { return typeToMatch instanceof Type.Time; } @Override - public Boolean visit(Type.TimestampTZ type) { + public Boolean visit(final Type.TimestampTZ type) { return typeToMatch instanceof Type.TimestampTZ; } @Override - public Boolean visit(Type.Timestamp type) { + public Boolean visit(final Type.Timestamp type) { return typeToMatch instanceof Type.Timestamp; } @Override - public Boolean visit(Type.IntervalYear type) { + public Boolean visit(final Type.IntervalYear type) { return typeToMatch instanceof Type.IntervalYear; } @Override - public Boolean visit(Type.IntervalDay type) { + public Boolean visit(final Type.IntervalDay type) { return typeToMatch instanceof Type.IntervalDay || typeToMatch instanceof ParameterizedType.IntervalDay; } @Override - public Boolean visit(Type.IntervalCompound type) { + public Boolean visit(final Type.IntervalCompound type) { return typeToMatch instanceof Type.IntervalCompound || typeToMatch instanceof ParameterizedType.IntervalCompound; } @Override - public Boolean visit(Type.UUID type) { + public Boolean visit(final Type.UUID type) { return typeToMatch instanceof Type.UUID; } @Override - public Boolean visit(Type.UserDefined type) throws RuntimeException { + public Boolean visit(final Type.UserDefined type) throws RuntimeException { // Two user-defined types are equal if they have the same uri AND name return typeToMatch.equals(type); } @Override - public Boolean visit(Type.FixedChar type) { + public Boolean visit(final Type.FixedChar type) { return typeToMatch instanceof Type.FixedChar || typeToMatch instanceof ParameterizedType.FixedChar; } @Override - public Boolean visit(Type.VarChar type) { + public Boolean visit(final Type.VarChar type) { return typeToMatch instanceof Type.VarChar || typeToMatch instanceof ParameterizedType.VarChar; } @Override - public Boolean visit(Type.FixedBinary type) { + public Boolean visit(final Type.FixedBinary type) { return typeToMatch instanceof Type.FixedBinary || typeToMatch instanceof ParameterizedType.FixedBinary; } @Override - public Boolean visit(Type.Decimal type) { + public Boolean visit(final Type.Decimal type) { return typeToMatch instanceof Type.Decimal || typeToMatch instanceof ParameterizedType.Decimal; } @Override - public Boolean visit(Type.PrecisionTime type) { + public Boolean visit(final Type.PrecisionTime type) { return typeToMatch instanceof Type.PrecisionTime || typeToMatch instanceof ParameterizedType.PrecisionTime; } @Override - public Boolean visit(Type.PrecisionTimestamp type) { + public Boolean visit(final Type.PrecisionTimestamp type) { return typeToMatch instanceof Type.PrecisionTimestamp || typeToMatch instanceof ParameterizedType.PrecisionTimestamp; } @Override - public Boolean visit(Type.PrecisionTimestampTZ type) { + public Boolean visit(final Type.PrecisionTimestampTZ type) { return typeToMatch instanceof Type.PrecisionTimestampTZ || typeToMatch instanceof ParameterizedType.PrecisionTimestampTZ; } @Override - public Boolean visit(Type.Struct type) { + public Boolean visit(final Type.Struct type) { return typeToMatch instanceof Type.Struct || typeToMatch instanceof ParameterizedType.Struct; } @Override - public Boolean visit(Type.ListType type) { + public Boolean visit(final Type.ListType type) { return typeToMatch instanceof Type.ListType || typeToMatch instanceof ParameterizedType.ListType; } @Override - public Boolean visit(Type.Map type) { + public Boolean visit(final Type.Map type) { return typeToMatch instanceof Type.Map || typeToMatch instanceof ParameterizedType.Map; } @Override - public Boolean visit(ParameterizedType.FixedChar expr) throws RuntimeException { + public Boolean visit(final ParameterizedType.FixedChar expr) throws RuntimeException { return typeToMatch instanceof Type.FixedChar || typeToMatch instanceof ParameterizedType.FixedChar; } @Override - public Boolean visit(ParameterizedType.VarChar expr) throws RuntimeException { + public Boolean visit(final ParameterizedType.VarChar expr) throws RuntimeException { return typeToMatch instanceof Type.VarChar || typeToMatch instanceof ParameterizedType.VarChar; } @Override - public Boolean visit(ParameterizedType.FixedBinary expr) throws RuntimeException { + public Boolean visit(final ParameterizedType.FixedBinary expr) throws RuntimeException { return typeToMatch instanceof Type.FixedBinary || typeToMatch instanceof ParameterizedType.FixedBinary; } @Override - public Boolean visit(ParameterizedType.Decimal expr) throws RuntimeException { + public Boolean visit(final ParameterizedType.Decimal expr) throws RuntimeException { return typeToMatch instanceof Type.Decimal || typeToMatch instanceof ParameterizedType.Decimal; } @Override - public Boolean visit(ParameterizedType.IntervalDay expr) throws RuntimeException { + public Boolean visit(final ParameterizedType.IntervalDay expr) throws RuntimeException { return typeToMatch instanceof Type.IntervalDay || typeToMatch instanceof ParameterizedType.IntervalDay; } @Override - public Boolean visit(ParameterizedType.IntervalCompound expr) throws RuntimeException { + public Boolean visit(final ParameterizedType.IntervalCompound expr) throws RuntimeException { return typeToMatch instanceof Type.IntervalCompound || typeToMatch instanceof ParameterizedType.IntervalCompound; } @Override - public Boolean visit(ParameterizedType.PrecisionTime expr) throws RuntimeException { + public Boolean visit(final ParameterizedType.PrecisionTime expr) throws RuntimeException { return typeToMatch instanceof Type.PrecisionTime || typeToMatch instanceof ParameterizedType.PrecisionTime; } @Override - public Boolean visit(ParameterizedType.PrecisionTimestamp expr) throws RuntimeException { + public Boolean visit(final ParameterizedType.PrecisionTimestamp expr) throws RuntimeException { return typeToMatch instanceof Type.PrecisionTimestamp || typeToMatch instanceof ParameterizedType.PrecisionTimestamp; } @Override - public Boolean visit(ParameterizedType.PrecisionTimestampTZ expr) throws RuntimeException { + public Boolean visit(final ParameterizedType.PrecisionTimestampTZ expr) throws RuntimeException { return typeToMatch instanceof Type.PrecisionTimestampTZ || typeToMatch instanceof ParameterizedType.PrecisionTimestampTZ; } @Override - public Boolean visit(ParameterizedType.Struct expr) throws RuntimeException { + public Boolean visit(final ParameterizedType.Struct expr) throws RuntimeException { return typeToMatch instanceof Type.Struct || typeToMatch instanceof ParameterizedType.Struct; } @Override - public Boolean visit(ParameterizedType.ListType expr) throws RuntimeException { + public Boolean visit(final ParameterizedType.ListType expr) throws RuntimeException { return typeToMatch instanceof Type.ListType || typeToMatch instanceof ParameterizedType.ListType; } @Override - public Boolean visit(ParameterizedType.Map expr) throws RuntimeException { + public Boolean visit(final ParameterizedType.Map expr) throws RuntimeException { return typeToMatch instanceof Type.Map || typeToMatch instanceof ParameterizedType.Map; } @Override - public Boolean visit(ParameterizedType.StringLiteral stringLiteral) throws RuntimeException { + public Boolean visit(final ParameterizedType.StringLiteral stringLiteral) + throws RuntimeException { return false; } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ListSqlOperatorFunctions.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ListSqlOperatorFunctions.java index 429e90c8b..02cf6832f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ListSqlOperatorFunctions.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ListSqlOperatorFunctions.java @@ -12,8 +12,8 @@ public class ListSqlOperatorFunctions { - public static void main(String[] args) { - Map operators = + public static void main(final String[] args) { + final Map operators = Arrays.stream(SqlStdOperatorTable.class.getFields()) .filter( f -> { @@ -35,10 +35,10 @@ public static void main(String[] args) { System.out.println("Operator count: " + operators.size()); } - private static SqlOperator toOp(Field f) { + private static SqlOperator toOp(final Field f) { try { return (SqlOperator) f.get(null); - } catch (IllegalAccessException e) { + } catch (final IllegalAccessException e) { throw new IllegalStateException(e); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConstructorConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConstructorConverter.java index a0f6b88d1..3437962df 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConstructorConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConstructorConverter.java @@ -21,14 +21,14 @@ public class LiteralConstructorConverter implements CallConverter { private final TypeConverter typeConverter; - public LiteralConstructorConverter(TypeConverter typeConverter) { + public LiteralConstructorConverter(final TypeConverter typeConverter) { this.typeConverter = typeConverter; } @Override public Optional convert( - RexCall call, Function topLevelConverter) { - SqlOperator operator = call.getOperator(); + final RexCall call, final Function topLevelConverter) { + final SqlOperator operator = call.getOperator(); if (operator instanceof SqlArrayValueConstructor) { return call.getOperands().isEmpty() ? toEmptyListLiteral(call) @@ -40,12 +40,12 @@ public Optional convert( } private Optional toMapLiteral( - RexCall call, Function topLevelConverter) { - List literals = + final RexCall call, final Function topLevelConverter) { + final List literals = call.operands.stream() .map(t -> ((Expression.Literal) topLevelConverter.apply(t))) .collect(java.util.stream.Collectors.toList()); - Map items = new HashMap<>(); + final Map items = new HashMap<>(); assert literals.size() % 2 == 0; for (int i = 0; i < literals.size(); i += 2) { items.put(literals.get(i), literals.get(i + 1)); @@ -54,7 +54,7 @@ private Optional toMapLiteral( } private Optional toNonEmptyListLiteral( - RexCall call, Function topLevelConverter) { + final RexCall call, final Function topLevelConverter) { return Optional.of( ExpressionCreator.list( call.getType().isNullable(), @@ -63,9 +63,9 @@ private Optional toNonEmptyListLiteral( .collect(java.util.stream.Collectors.toList()))); } - private Optional toEmptyListLiteral(RexCall call) { - RelDataType calciteElementType = call.getType().getComponentType(); - Type substraitElementType = typeConverter.toSubstrait(calciteElementType); + private Optional toEmptyListLiteral(final RexCall call) { + final RelDataType calciteElementType = call.getType().getComponentType(); + final Type substraitElementType = typeConverter.toSubstrait(calciteElementType); return Optional.of( ExpressionCreator.emptyList(call.getType().isNullable(), substraitElementType)); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java index 02cb8a116..8bacbe232 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/LiteralConverter.java @@ -50,23 +50,23 @@ public class LiteralConverter { private final TypeConverter typeConverter; - public LiteralConverter(TypeConverter typeConverter) { + public LiteralConverter(final TypeConverter typeConverter) { this.typeConverter = typeConverter; } - private static BigDecimal i(RexLiteral literal) { + private static BigDecimal i(final RexLiteral literal) { return bd(literal).setScale(0, RoundingMode.HALF_UP); } - private static String s(RexLiteral literal) { + private static String s(final RexLiteral literal) { return ((NlsString) literal.getValue()).getValue(); } - private static BigDecimal bd(RexLiteral literal) { + private static BigDecimal bd(final RexLiteral literal) { return (BigDecimal) literal.getValue(); } - public Expression.Literal convert(RexLiteral literal) { + public Expression.Literal convert(final RexLiteral literal) { // convert type first to guarantee we can handle the value. final Type type = typeConverter.toSubstrait(literal.getType()); final boolean n = type.nullable(); @@ -88,9 +88,9 @@ public Expression.Literal convert(RexLiteral literal) { return ExpressionCreator.bool(n, literal.getValueAs(Boolean.class)); case CHAR: { - Comparable val = literal.getValue(); + final Comparable val = literal.getValue(); if (val instanceof NlsString) { - NlsString nls = (NlsString) val; + final NlsString nls = (NlsString) val; return ExpressionCreator.fixedChar(n, nls.getValue()); } throw new UnsupportedOperationException("Unable to handle char type: " + val); @@ -103,7 +103,7 @@ public Expression.Literal convert(RexLiteral literal) { case DECIMAL: { - BigDecimal bd = bd(literal); + final BigDecimal bd = bd(literal); return ExpressionCreator.decimal( n, bd, literal.getType().getPrecision(), literal.getType().getScale()); } @@ -126,13 +126,13 @@ public Expression.Literal convert(RexLiteral literal) { return ExpressionCreator.binary(n, ByteString.copyFrom(literal.getValueAs(byte[].class))); case SYMBOL: { - Object value = literal.getValue(); + final Object value = literal.getValue(); // case TimeUnitRange tur -> string(n, tur.name()); if (value instanceof NlsString) { return ExpressionCreator.string(n, ((NlsString) value).getValue()); } else if (value instanceof Enum) { - Enum v = (Enum) value; - Optional r = + final Enum v = (Enum) value; + final Optional r = EnumConverter.canConvert(v) ? Optional.of(ExpressionCreator.string(n, v.name())) : Optional.empty(); @@ -144,21 +144,23 @@ public Expression.Literal convert(RexLiteral literal) { } case DATE: { - DateString date = literal.getValueAs(DateString.class); - LocalDate localDate = LocalDate.parse(date.toString(), CALCITE_LOCAL_DATE_FORMATTER); + final DateString date = literal.getValueAs(DateString.class); + final LocalDate localDate = + LocalDate.parse(date.toString(), CALCITE_LOCAL_DATE_FORMATTER); return ExpressionCreator.date(n, (int) localDate.toEpochDay()); } case TIME: { - TimeString time = literal.getValueAs(TimeString.class); - LocalTime localTime = LocalTime.parse(time.toString(), CALCITE_LOCAL_TIME_FORMATTER); + final TimeString time = literal.getValueAs(TimeString.class); + final LocalTime localTime = + LocalTime.parse(time.toString(), CALCITE_LOCAL_TIME_FORMATTER); return ExpressionCreator.time(n, TimeUnit.NANOSECONDS.toMicros(localTime.toNanoOfDay())); } case TIMESTAMP: case TIMESTAMP_WITH_LOCAL_TIME_ZONE: { - TimestampString timestamp = literal.getValueAs(TimestampString.class); - LocalDateTime ldt = + final TimestampString timestamp = literal.getValueAs(TimestampString.class); + final LocalDateTime ldt = LocalDateTime.parse(timestamp.toString(), CALCITE_LOCAL_DATETIME_FORMATTER); return ExpressionCreator.timestamp(n, ldt); } @@ -166,9 +168,9 @@ public Expression.Literal convert(RexLiteral literal) { case INTERVAL_YEAR_MONTH: case INTERVAL_MONTH: { - long intervalLength = Objects.requireNonNull(literal.getValueAs(Long.class)); - long years = intervalLength / 12; - long months = intervalLength - years * 12; + final long intervalLength = Objects.requireNonNull(literal.getValueAs(Long.class)); + final long years = intervalLength / 12; + final long months = intervalLength - years * 12; return ExpressionCreator.intervalYear(n, (int) years, (int) months); } case INTERVAL_DAY: @@ -183,26 +185,26 @@ public Expression.Literal convert(RexLiteral literal) { case INTERVAL_SECOND: { // Calcite represents day/time intervals in milliseconds, despite a default scale of 6. - Long totalMillis = Objects.requireNonNull(literal.getValueAs(Long.class)); - Duration interval = Duration.ofMillis(totalMillis); + final Long totalMillis = Objects.requireNonNull(literal.getValueAs(Long.class)); + final Duration interval = Duration.ofMillis(totalMillis); - long days = interval.toDays(); - long seconds = interval.minusDays(days).toSeconds(); - int micros = interval.toMillisPart() * 1000; + final long days = interval.toDays(); + final long seconds = interval.minusDays(days).toSeconds(); + final int micros = interval.toMillisPart() * 1000; return ExpressionCreator.intervalDay(n, (int) days, (int) seconds, micros, 6); } case ROW: { - List literals = (List) literal.getValue(); + final List literals = (List) literal.getValue(); return ExpressionCreator.struct( n, literals.stream().map(this::convert).collect(Collectors.toList())); } case ARRAY: { - List literals = (List) literal.getValue(); + final List literals = (List) literal.getValue(); return ExpressionCreator.list( n, literals.stream().map(this::convert).collect(Collectors.toList())); } @@ -216,11 +218,11 @@ public Expression.Literal convert(RexLiteral literal) { } public static byte[] padRightIfNeeded( - org.apache.calcite.avatica.util.ByteString bytes, int length) { + final org.apache.calcite.avatica.util.ByteString bytes, final int length) { return padRightIfNeeded(bytes.getBytes(), length); } - public static byte[] padRightIfNeeded(byte[] value, int length) { + public static byte[] padRightIfNeeded(final byte[] value, final int length) { if (length < value.length) { throw new IllegalArgumentException( @@ -231,7 +233,7 @@ public static byte[] padRightIfNeeded(byte[] value, int length) { return value; } - byte[] newArray = new byte[length]; + final byte[] newArray = new byte[length]; System.arraycopy(value, 0, newArray, 0, value.length); return newArray; } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 6993c8451..4deb0c10c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -37,21 +37,22 @@ public class RexExpressionConverter implements RexVisitor { private final List callConverters; private final SubstraitRelVisitor relVisitor; private final TypeConverter typeConverter; - private WindowFunctionConverter windowFunctionConverter; + private final WindowFunctionConverter windowFunctionConverter; - public RexExpressionConverter(SubstraitRelVisitor relVisitor, CallConverter... callConverters) { + public RexExpressionConverter( + final SubstraitRelVisitor relVisitor, final CallConverter... callConverters) { this(relVisitor, Arrays.asList(callConverters), null, TypeConverter.DEFAULT); } - public RexExpressionConverter(CallConverter... callConverters) { + public RexExpressionConverter(final CallConverter... callConverters) { this(null, Arrays.asList(callConverters), null, TypeConverter.DEFAULT); } public RexExpressionConverter( - SubstraitRelVisitor relVisitor, - List callConverters, - WindowFunctionConverter windowFunctionConverter, - TypeConverter typeConverter) { + final SubstraitRelVisitor relVisitor, + final List callConverters, + final WindowFunctionConverter windowFunctionConverter, + final TypeConverter typeConverter) { this.callConverters = callConverters; this.relVisitor = relVisitor; this.windowFunctionConverter = windowFunctionConverter; @@ -67,15 +68,15 @@ public RexExpressionConverter() { } @Override - public Expression visitInputRef(RexInputRef inputRef) { + public Expression visitInputRef(final RexInputRef inputRef) { return FieldReference.newRootStructReference( inputRef.getIndex(), typeConverter.toSubstrait(inputRef.getType())); } @Override - public Expression visitCall(RexCall call) { - for (CallConverter c : callConverters) { - Optional out = c.convert(call, rexNode -> rexNode.accept(this)); + public Expression visitCall(final RexCall call) { + for (final CallConverter c : callConverters) { + final Optional out = c.convert(call, rexNode -> rexNode.accept(this)); if (out.isPresent()) { return out.get(); } @@ -84,7 +85,7 @@ public Expression visitCall(RexCall call) { throw new IllegalArgumentException(callConversionFailureMessage(call)); } - private String callConversionFailureMessage(RexCall call) { + private String callConversionFailureMessage(final RexCall call) { return String.format( "Unable to convert call %s(%s).", call.getOperator().getName(), @@ -94,12 +95,12 @@ private String callConversionFailureMessage(RexCall call) { } @Override - public Expression visitLiteral(RexLiteral literal) { + public Expression visitLiteral(final RexLiteral literal) { return (new LiteralConverter(typeConverter)).convert(literal); } @Override - public Expression visitOver(RexOver over) { + public Expression visitOver(final RexOver over) { if (over.ignoreNulls()) { throw new IllegalArgumentException("IGNORE NULLS cannot be expressed in Substrait"); } @@ -110,27 +111,27 @@ public Expression visitOver(RexOver over) { } @Override - public Expression visitCorrelVariable(RexCorrelVariable correlVariable) { + public Expression visitCorrelVariable(final RexCorrelVariable correlVariable) { throw new UnsupportedOperationException("RexCorrelVariable not supported"); } @Override - public Expression visitDynamicParam(RexDynamicParam dynamicParam) { + public Expression visitDynamicParam(final RexDynamicParam dynamicParam) { throw new UnsupportedOperationException("RexDynamicParam not supported"); } @Override - public Expression visitRangeRef(RexRangeRef rangeRef) { + public Expression visitRangeRef(final RexRangeRef rangeRef) { throw new UnsupportedOperationException("RexRangeRef not supported"); } @Override - public Expression visitFieldAccess(RexFieldAccess fieldAccess) { - SqlKind kind = fieldAccess.getReferenceExpr().getKind(); + public Expression visitFieldAccess(final RexFieldAccess fieldAccess) { + final SqlKind kind = fieldAccess.getReferenceExpr().getKind(); switch (kind) { case CORREL_VARIABLE: { - int stepsOut = relVisitor.getFieldAccessDepth(fieldAccess); + final int stepsOut = relVisitor.getFieldAccessDepth(fieldAccess); return FieldReference.newRootStructOuterReference( fieldAccess.getField().getIndex(), @@ -141,9 +142,9 @@ public Expression visitFieldAccess(RexFieldAccess fieldAccess) { case INPUT_REF: case FIELD_ACCESS: { - Expression expression = fieldAccess.getReferenceExpr().accept(this); + final Expression expression = fieldAccess.getReferenceExpr().accept(this); if (expression instanceof FieldReference) { - FieldReference nestedReference = (FieldReference) expression; + final FieldReference nestedReference = (FieldReference) expression; return nestedReference.dereferenceStruct(fieldAccess.getField().getIndex()); } else { return FieldReference.newStructReference(fieldAccess.getField().getIndex(), expression); @@ -156,8 +157,8 @@ public Expression visitFieldAccess(RexFieldAccess fieldAccess) { } @Override - public Expression visitSubQuery(RexSubQuery subQuery) { - Rel rel = relVisitor.apply(subQuery.rel); + public Expression visitSubQuery(final RexSubQuery subQuery) { + final Rel rel = relVisitor.apply(subQuery.rel); if (subQuery.getOperator() == SqlStdOperatorTable.EXISTS) { return Expression.SetPredicate.builder() @@ -175,8 +176,8 @@ public Expression visitSubQuery(RexSubQuery subQuery) { .type(typeConverter.toSubstrait(subQuery.getType())) .build(); } else if (subQuery.getOperator() == SqlStdOperatorTable.IN) { - List needles = new ArrayList<>(); - for (RexNode inOperand : subQuery.getOperands()) { + final List needles = new ArrayList<>(); + for (final RexNode inOperand : subQuery.getOperands()) { needles.add(inOperand.accept(this)); } return Expression.InPredicate.builder().needles(needles).haystack(rel).build(); @@ -186,32 +187,32 @@ public Expression visitSubQuery(RexSubQuery subQuery) { } @Override - public Expression visitTableInputRef(RexTableInputRef fieldRef) { + public Expression visitTableInputRef(final RexTableInputRef fieldRef) { throw new UnsupportedOperationException("RexTableInputRef not supported"); } @Override - public Expression visitLocalRef(RexLocalRef localRef) { + public Expression visitLocalRef(final RexLocalRef localRef) { throw new UnsupportedOperationException("RexLocalRef not supported"); } @Override - public Expression visitPatternFieldRef(RexPatternFieldRef fieldRef) { + public Expression visitPatternFieldRef(final RexPatternFieldRef fieldRef) { throw new UnsupportedOperationException("RexPatternFieldRef not supported"); } @Override - public Expression visitLambda(RexLambda rexLambda) { + public Expression visitLambda(final RexLambda rexLambda) { throw new UnsupportedOperationException("RexLambda not supported"); } @Override - public Expression visitLambdaRef(RexLambdaRef rexLambdaRef) { + public Expression visitLambdaRef(final RexLambdaRef rexLambdaRef) { throw new UnsupportedOperationException("RexLambdaRef not supported"); } @Override - public Expression visitNodeAndFieldIndex(RexNodeAndFieldIndex nodeAndFieldIndex) { + public Expression visitNodeAndFieldIndex(final RexNodeAndFieldIndex nodeAndFieldIndex) { throw new UnsupportedOperationException("RexNodeAndFieldIndex not supported"); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java index b3ad6514c..bb03d4062 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ScalarFunctionConverter.java @@ -32,15 +32,16 @@ public class ScalarFunctionConverter private final List mappers; public ScalarFunctionConverter( - List functions, RelDataTypeFactory typeFactory) { + final List functions, + final RelDataTypeFactory typeFactory) { this(functions, Collections.emptyList(), typeFactory, TypeConverter.DEFAULT); } public ScalarFunctionConverter( - List functions, - List additionalSignatures, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter) { + final List functions, + final List additionalSignatures, + final RelDataTypeFactory typeFactory, + final TypeConverter typeConverter) { super(functions, additionalSignatures, typeFactory, typeConverter); mappers = List.of(new TrimFunctionMapper(functions), new SqrtFunctionMapper(functions)); @@ -53,7 +54,7 @@ protected ImmutableList getSigs() { @Override public Optional convert( - RexCall call, Function topLevelConverter) { + final RexCall call, final Function topLevelConverter) { // If a mapping applies to this call, use it; otherwise default behavior. return getMappingForCall(call) .map(mapping -> mappedConvert(mapping, call, topLevelConverter)) @@ -69,12 +70,12 @@ private Optional getMappingForCall(final RexCall call) } private Optional mappedConvert( - SubstraitFunctionMapping mapping, - RexCall call, - Function topLevelConverter) { - FunctionFinder finder = + final SubstraitFunctionMapping mapping, + final RexCall call, + final Function topLevelConverter) { + final FunctionFinder finder = new FunctionFinder(mapping.substraitName(), call.op, mapping.functions()); - WrappedScalarCall wrapped = + final WrappedScalarCall wrapped = new WrappedScalarCall(call) { @Override public Stream getOperands() { @@ -86,17 +87,17 @@ public Stream getOperands() { } private Optional defaultConvert( - RexCall call, Function topLevelConverter) { - FunctionFinder finder = signatures.get(call.op); - WrappedScalarCall wrapped = new WrappedScalarCall(call); + final RexCall call, final Function topLevelConverter) { + final FunctionFinder finder = signatures.get(call.op); + final WrappedScalarCall wrapped = new WrappedScalarCall(call); return attemptMatch(finder, wrapped, topLevelConverter); } private Optional attemptMatch( - FunctionFinder finder, - WrappedScalarCall call, - Function topLevelConverter) { + final FunctionFinder finder, + final WrappedScalarCall call, + final Function topLevelConverter) { if (!isPotentialFunctionMatch(finder, call)) { return Optional.empty(); } @@ -104,16 +105,17 @@ private Optional attemptMatch( return finder.attemptMatch(call, topLevelConverter); } - private boolean isPotentialFunctionMatch(FunctionFinder finder, WrappedScalarCall call) { + private boolean isPotentialFunctionMatch( + final FunctionFinder finder, final WrappedScalarCall call) { return Objects.nonNull(finder) && finder.allowedArgCount((int) call.getOperands().count()); } @Override protected Expression generateBinding( - WrappedScalarCall call, - SimpleExtension.ScalarFunctionVariant function, - List arguments, - Type outputType) { + final WrappedScalarCall call, + final SimpleExtension.ScalarFunctionVariant function, + final List arguments, + final Type outputType) { return Expression.ScalarFunctionInvocation.builder() .outputType(outputType) .declaration(function) @@ -121,14 +123,15 @@ protected Expression generateBinding( .build(); } - public List getExpressionArguments(Expression.ScalarFunctionInvocation expression) { + public List getExpressionArguments( + final Expression.ScalarFunctionInvocation expression) { // If a mapping applies to this expression, use it to get the arguments; otherwise default // behavior. return getMappedExpressionArguments(expression).orElseGet(expression::arguments); } private Optional> getMappedExpressionArguments( - Expression.ScalarFunctionInvocation expression) { + final Expression.ScalarFunctionInvocation expression) { return mappers.stream() .map(mapper -> mapper.getExpressionArguments(expression)) .filter(Optional::isPresent) @@ -140,7 +143,7 @@ protected static class WrappedScalarCall implements FunctionConverter.GenericCal private final RexCall delegate; - private WrappedScalarCall(RexCall delegate) { + private WrappedScalarCall(final RexCall delegate) { this.delegate = delegate; } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java index 773632164..dadc02115 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java @@ -9,15 +9,16 @@ public class SortFieldConverter { /** Converts a {@link RexFieldCollation} to a {@link Expression.SortField}. */ public static Expression.SortField toSortField( - RexFieldCollation rexFieldCollation, RexExpressionConverter rexExpressionConverter) { - Expression expr = rexFieldCollation.left.accept(rexExpressionConverter); - Expression.SortDirection direction = asSortDirection(rexFieldCollation); + final RexFieldCollation rexFieldCollation, + final RexExpressionConverter rexExpressionConverter) { + final Expression expr = rexFieldCollation.left.accept(rexExpressionConverter); + final Expression.SortDirection direction = asSortDirection(rexFieldCollation); return Expression.SortField.builder().expr(expr).direction(direction).build(); } - private static Expression.SortDirection asSortDirection(RexFieldCollation collation) { - RelFieldCollation.Direction direction = collation.getDirection(); + private static Expression.SortDirection asSortDirection(final RexFieldCollation collation) { + final RelFieldCollation.Direction direction = collation.getDirection(); if (direction == Direction.ASCENDING) { return collation.getNullDirection() == RelFieldCollation.NullDirection.LAST diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/SqrtFunctionMapper.java b/isthmus/src/main/java/io/substrait/isthmus/expression/SqrtFunctionMapper.java index e34ceca02..0452cb8d8 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/SqrtFunctionMapper.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/SqrtFunctionMapper.java @@ -21,7 +21,7 @@ final class SqrtFunctionMapper implements ScalarFunctionMapper { private static final String sqrtFunctionName = "sqrt"; private final List sqrtFunctions; - public SqrtFunctionMapper(List functions) { + public SqrtFunctionMapper(final List functions) { this.sqrtFunctions = functions.stream() .filter(f -> sqrtFunctionName.equalsIgnoreCase(f.name())) @@ -29,13 +29,13 @@ public SqrtFunctionMapper(List functions) { } @Override - public Optional toSubstrait(RexCall call) { + public Optional toSubstrait(final RexCall call) { if (sqrtFunctions.isEmpty()) { return Optional.empty(); } if (isPowerOfHalf(call)) { - List operands = call.getOperands().subList(0, 1); + final List operands = call.getOperands().subList(0, 1); return Optional.of(new SubstraitFunctionMapping(sqrtFunctionName, operands, sqrtFunctions)); } @@ -55,7 +55,7 @@ private static boolean isPowerOfHalf(final RexCall call) { if (!(exponent instanceof RexLiteral)) { return false; } - RexLiteral literal = (RexLiteral) exponent; + final RexLiteral literal = (RexLiteral) exponent; switch (literal.getType().getSqlTypeName()) { case DOUBLE: diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/TrimFunctionMapper.java b/isthmus/src/main/java/io/substrait/isthmus/expression/TrimFunctionMapper.java index 06adbc662..069a5b4cf 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/TrimFunctionMapper.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/TrimFunctionMapper.java @@ -41,7 +41,7 @@ private enum Trim { private final String substraitName; private final SqlTrimFunction.Flag flag; - Trim(String substraitName, SqlTrimFunction.Flag flag) { + Trim(final String substraitName, final SqlTrimFunction.Flag flag) { this.substraitName = substraitName; this.flag = flag; } @@ -54,21 +54,21 @@ public SqlTrimFunction.Flag flag() { return flag; } - public static Optional fromFlag(SqlTrimFunction.Flag flag) { + public static Optional fromFlag(final SqlTrimFunction.Flag flag) { return Arrays.stream(values()).filter(t -> t.flag == flag).findAny(); } - public static Optional fromSubstraitName(String name) { + public static Optional fromSubstraitName(final String name) { return Arrays.stream(values()).filter(t -> t.substraitName.equals(name)).findAny(); } } private final Map> trimFunctions; - public TrimFunctionMapper(List functions) { - Map> trims = new HashMap<>(); - for (Trim t : Trim.values()) { - List funcs = findFunction(t.substraitName(), functions); + public TrimFunctionMapper(final List functions) { + final Map> trims = new HashMap<>(); + for (final Trim t : Trim.values()) { + final List funcs = findFunction(t.substraitName(), functions); if (!funcs.isEmpty()) { trims.put(t, funcs); } @@ -77,7 +77,7 @@ public TrimFunctionMapper(List functions) { } private List findFunction( - String name, Collection functions) { + final String name, final Collection functions) { return functions.stream() .filter(f -> name.equals(f.name())) .collect(Collectors.toUnmodifiableList()); @@ -89,29 +89,29 @@ public Optional toSubstrait(final RexCall call) { return Optional.empty(); } - Optional trimType = getTrimCallType(call); + final Optional trimType = getTrimCallType(call); return trimType.map( trim -> { - List functions = trimFunctions.getOrDefault(trim, List.of()); + final List functions = trimFunctions.getOrDefault(trim, List.of()); if (functions.isEmpty()) { return null; } - String name = trim.substraitName(); - List operands = + final String name = trim.substraitName(); + final List operands = call.getOperands().stream().skip(1).collect(Collectors.toUnmodifiableList()); return new SubstraitFunctionMapping(name, operands, functions); }); } - private Optional getTrimCallType(RexCall call) { - RexNode trimType = call.operands.get(0); + private Optional getTrimCallType(final RexCall call) { + final RexNode trimType = call.operands.get(0); if (trimType.getType().getSqlTypeName() != SqlTypeName.SYMBOL) { return Optional.empty(); } - Comparable value = ((RexLiteral) trimType).getValue(); + final Comparable value = ((RexLiteral) trimType).getValue(); if (!(value instanceof SqlTrimFunction.Flag)) { return Optional.empty(); } @@ -122,14 +122,14 @@ private Optional getTrimCallType(RexCall call) { @Override public Optional> getExpressionArguments( final Expression.ScalarFunctionInvocation expression) { - String name = expression.declaration().name(); + final String name = expression.declaration().name(); return Trim.fromSubstraitName(name) .map(Trim::flag) .map(SqlTrimFunction.Flag::name) .map(EnumArg::of) .map( trimTypeArg -> { - LinkedList args = new LinkedList<>(expression.arguments()); + final LinkedList args = new LinkedList<>(expression.arguments()); args.addFirst(trimTypeArg); return args; }); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java index 8979b0d26..107712e99 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java @@ -10,19 +10,19 @@ public class WindowBoundConverter { /** Converts a {@link RexWindowBound} to a {@link WindowBound}. */ - public static WindowBound toWindowBound(RexWindowBound rexWindowBound) { + public static WindowBound toWindowBound(final RexWindowBound rexWindowBound) { if (rexWindowBound.isCurrentRow()) { return WindowBound.CURRENT_ROW; } if (rexWindowBound.isUnbounded()) { return WindowBound.UNBOUNDED; } else { - RexNode node = rexWindowBound.getOffset(); + final RexNode node = rexWindowBound.getOffset(); if (node instanceof RexLiteral) { - RexLiteral literal = (RexLiteral) node; + final RexLiteral literal = (RexLiteral) node; if (SqlTypeName.EXACT_TYPES.contains(literal.getTypeName())) { - BigDecimal offset = (BigDecimal) literal.getValue4(); + final BigDecimal offset = (BigDecimal) literal.getValue4(); if (rexWindowBound.isPreceding()) { return WindowBound.Preceding.of(offset.longValue()); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java index 9e35492e3..0d65c60c2 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java @@ -36,48 +36,49 @@ protected ImmutableList getSigs() { } public WindowFunctionConverter( - List functions, RelDataTypeFactory typeFactory) { + final List functions, + final RelDataTypeFactory typeFactory) { super(functions, typeFactory); } public WindowFunctionConverter( - List functions, - List additionalSignatures, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter) { + final List functions, + final List additionalSignatures, + final RelDataTypeFactory typeFactory, + final TypeConverter typeConverter) { super(functions, additionalSignatures, typeFactory, typeConverter); } @Override protected Expression.WindowFunctionInvocation generateBinding( - WrappedWindowCall call, - SimpleExtension.WindowFunctionVariant function, - List arguments, - Type outputType) { - RexOver over = call.over; - RexWindow window = over.getWindow(); - - List partitionExprs = + final WrappedWindowCall call, + final SimpleExtension.WindowFunctionVariant function, + final List arguments, + final Type outputType) { + final RexOver over = call.over; + final RexWindow window = over.getWindow(); + + final List partitionExprs = window.partitionKeys.stream() .map(r -> r.accept(call.rexExpressionConverter)) .collect(java.util.stream.Collectors.toList()); - List sorts = + final List sorts = window.orderKeys != null ? window.orderKeys.stream() .map(rfc -> toSortField(rfc, call.rexExpressionConverter)) .collect(java.util.stream.Collectors.toList()) : Collections.emptyList(); - Expression.AggregationInvocation invocation = + final Expression.AggregationInvocation invocation = over.isDistinct() ? Expression.AggregationInvocation.DISTINCT : Expression.AggregationInvocation.ALL; // Calcite only supports ROW or RANGE mode - Expression.WindowBoundsType boundsType = + final Expression.WindowBoundsType boundsType = window.isRows() ? Expression.WindowBoundsType.ROWS : Expression.WindowBoundsType.RANGE; - WindowBound lowerBound = toWindowBound(window.getLowerBound()); - WindowBound upperBound = toWindowBound(window.getUpperBound()); + final WindowBound lowerBound = toWindowBound(window.getLowerBound()); + final WindowBound upperBound = toWindowBound(window.getUpperBound()); return ExpressionCreator.windowFunction( function, @@ -93,14 +94,14 @@ protected Expression.WindowFunctionInvocation generateBinding( } public Optional convert( - RexOver over, - Function topLevelConverter, - RexExpressionConverter rexExpressionConverter) { - SqlAggFunction aggFunction = over.getAggOperator(); + final RexOver over, + final Function topLevelConverter, + final RexExpressionConverter rexExpressionConverter) { + final SqlAggFunction aggFunction = over.getAggOperator(); - SqlAggFunction lookupFunction = + final SqlAggFunction lookupFunction = AggregateFunctions.toSubstraitAggVariant(aggFunction).orElse(aggFunction); - FunctionFinder m = signatures.get(lookupFunction); + final FunctionFinder m = signatures.get(lookupFunction); if (m == null) { return Optional.empty(); } @@ -108,7 +109,7 @@ public Optional convert( return Optional.empty(); } - WrappedWindowCall wrapped = new WrappedWindowCall(over, rexExpressionConverter); + final WrappedWindowCall wrapped = new WrappedWindowCall(over, rexExpressionConverter); return m.attemptMatch(wrapped, topLevelConverter); } @@ -116,7 +117,8 @@ static class WrappedWindowCall implements FunctionConverter.GenericCall { private final RexOver over; private final RexExpressionConverter rexExpressionConverter; - private WrappedWindowCall(RexOver over, RexExpressionConverter rexExpressionConverter) { + private WrappedWindowCall( + final RexOver over, final RexExpressionConverter rexExpressionConverter) { this.over = over; this.rexExpressionConverter = rexExpressionConverter; } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java index b1b9a201f..888b21c74 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java @@ -35,36 +35,37 @@ protected ImmutableList getSigs() { } public WindowRelFunctionConverter( - List functions, RelDataTypeFactory typeFactory) { + final List functions, + final RelDataTypeFactory typeFactory) { super(functions, typeFactory); } public WindowRelFunctionConverter( - List functions, - List additionalSignatures, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter) { + final List functions, + final List additionalSignatures, + final RelDataTypeFactory typeFactory, + final TypeConverter typeConverter) { super(functions, additionalSignatures, typeFactory, typeConverter); } @Override protected ConsistentPartitionWindow.WindowRelFunctionInvocation generateBinding( - WrappedWindowRelCall call, - SimpleExtension.WindowFunctionVariant function, - List arguments, - Type outputType) { - Window.RexWinAggCall over = call.getWinAggCall(); + final WrappedWindowRelCall call, + final SimpleExtension.WindowFunctionVariant function, + final List arguments, + final Type outputType) { + final Window.RexWinAggCall over = call.getWinAggCall(); - Expression.AggregationInvocation invocation = + final Expression.AggregationInvocation invocation = over.distinct ? Expression.AggregationInvocation.DISTINCT : Expression.AggregationInvocation.ALL; // Calcite only supports ROW or RANGE mode - Expression.WindowBoundsType boundsType = + final Expression.WindowBoundsType boundsType = call.isRows() ? Expression.WindowBoundsType.ROWS : Expression.WindowBoundsType.RANGE; - WindowBound lowerBound = toWindowBound(call.getLowerBound()); - WindowBound upperBound = toWindowBound(call.getUpperBound()); + final WindowBound lowerBound = toWindowBound(call.getLowerBound()); + final WindowBound upperBound = toWindowBound(call.getUpperBound()); return ExpressionCreator.windowRelFunction( function, @@ -78,16 +79,16 @@ protected ConsistentPartitionWindow.WindowRelFunctionInvocation generateBinding( } public Optional convert( - Window.RexWinAggCall winAggCall, - RexWindowBound lowerBound, - RexWindowBound upperBound, - boolean isRows, - Function topLevelConverter) { - SqlAggFunction aggFunction = (SqlAggFunction) winAggCall.getOperator(); - - SqlAggFunction lookupFunction = + final Window.RexWinAggCall winAggCall, + final RexWindowBound lowerBound, + final RexWindowBound upperBound, + final boolean isRows, + final Function topLevelConverter) { + final SqlAggFunction aggFunction = (SqlAggFunction) winAggCall.getOperator(); + + final SqlAggFunction lookupFunction = AggregateFunctions.toSubstraitAggVariant(aggFunction).orElse(aggFunction); - FunctionFinder m = signatures.get(lookupFunction); + final FunctionFinder m = signatures.get(lookupFunction); if (m == null) { return Optional.empty(); } @@ -95,7 +96,7 @@ public Optional convert( return Optional.empty(); } - WrappedWindowRelCall wrapped = + final WrappedWindowRelCall wrapped = new WrappedWindowRelCall(winAggCall, lowerBound, upperBound, isRows); return m.attemptMatch(wrapped, topLevelConverter); } @@ -107,10 +108,10 @@ static class WrappedWindowRelCall implements FunctionConverter.GenericCall { private final boolean isRows; private WrappedWindowRelCall( - Window.RexWinAggCall winAggCall, - RexWindowBound lowerBound, - RexWindowBound upperBound, - boolean isRows) { + final Window.RexWinAggCall winAggCall, + final RexWindowBound lowerBound, + final RexWindowBound upperBound, + final boolean isRows) { this.winAggCall = winAggCall; this.lowerBound = lowerBound; this.upperBound = upperBound; diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlDialect.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlDialect.java index dd70db0cb..e510033b6 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlDialect.java +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlDialect.java @@ -17,9 +17,9 @@ public class SubstraitSqlDialect extends SqlDialect { public static SqlDialect DEFAULT = new SubstraitSqlDialect(DEFAULT_CONTEXT); - public static SqlString toSql(RelNode relNode) { - RelToSqlConverter relToSql = new RelToSqlConverter(DEFAULT); - SqlNode sqlNode = relToSql.visitRoot(relNode).asStatement(); + public static SqlString toSql(final RelNode relNode) { + final RelToSqlConverter relToSql = new RelToSqlConverter(DEFAULT); + final SqlNode sqlNode = relToSql.visitRoot(relNode).asStatement(); return sqlNode.toSqlString( c -> c.withAlwaysUseParentheses(false) @@ -28,7 +28,7 @@ public static SqlString toSql(RelNode relNode) { .withIndentation(0)); } - public SubstraitSqlDialect(Context context) { + public SubstraitSqlDialect(final Context context) { super(context); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlStatementParser.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlStatementParser.java index 0f7891b5b..fa2c581e5 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlStatementParser.java +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlStatementParser.java @@ -29,8 +29,8 @@ public class SubstraitSqlStatementParser { * @return a list of {@link SqlNode}s corresponding to the given statements * @throws SqlParseException if there is an error while parsing the SQL statements */ - public static List parseStatements(String sqlStatements) throws SqlParseException { - SqlParser parser = SqlParser.create(sqlStatements, PARSER_CONFIG); + public static List parseStatements(final String sqlStatements) throws SqlParseException { + final SqlParser parser = SqlParser.create(sqlStatements, PARSER_CONFIG); return parser.parseStmtList(); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java index a87e29563..44a9b7e6e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java @@ -35,9 +35,10 @@ public class SubstraitSqlToCalcite { * @return a {@link RelRoot} corresponding to the given SQL statement * @throws SqlParseException if there is an error while parsing the SQL statement */ - public static RelRoot convertQuery(String sqlStatement, Prepare.CatalogReader catalogReader) + public static RelRoot convertQuery( + final String sqlStatement, final Prepare.CatalogReader catalogReader) throws SqlParseException { - SqlValidator validator = new SubstraitSqlValidator(catalogReader); + final SqlValidator validator = new SubstraitSqlValidator(catalogReader); return convertQuery(sqlStatement, catalogReader, validator, createDefaultRelOptCluster()); } @@ -57,17 +58,17 @@ public static RelRoot convertQuery(String sqlStatement, Prepare.CatalogReader ca * @throws SqlParseException if there is an error while parsing the SQL statement string */ public static RelRoot convertQuery( - String sqlStatement, - Prepare.CatalogReader catalogReader, - SqlValidator validator, - RelOptCluster cluster) + final String sqlStatement, + final Prepare.CatalogReader catalogReader, + final SqlValidator validator, + final RelOptCluster cluster) throws SqlParseException { - List sqlNodes = SubstraitSqlStatementParser.parseStatements(sqlStatement); + final List sqlNodes = SubstraitSqlStatementParser.parseStatements(sqlStatement); if (sqlNodes.size() != 1) { throw new IllegalArgumentException( String.format("Expected one statement, found: %d", sqlNodes.size())); } - List relRoots = convert(sqlNodes, catalogReader, validator, cluster); + final List relRoots = convert(sqlNodes, catalogReader, validator, cluster); // as there was only 1 statement, there should only be 1 root return relRoots.get(0); } @@ -83,8 +84,9 @@ public static RelRoot convertQuery( * @throws SqlParseException if there is an error while parsing the SQL statements */ public static List convertQueries( - String sqlStatements, Prepare.CatalogReader catalogReader) throws SqlParseException { - SqlValidator validator = new SubstraitSqlValidator(catalogReader); + final String sqlStatements, final Prepare.CatalogReader catalogReader) + throws SqlParseException { + final SqlValidator validator = new SubstraitSqlValidator(catalogReader); return convertQueries(sqlStatements, catalogReader, validator, createDefaultRelOptCluster()); } @@ -105,22 +107,22 @@ public static List convertQueries( * @throws SqlParseException if there is an error while parsing the SQL statements */ public static List convertQueries( - String sqlStatements, - Prepare.CatalogReader catalogReader, - SqlValidator validator, - RelOptCluster cluster) + final String sqlStatements, + final Prepare.CatalogReader catalogReader, + final SqlValidator validator, + final RelOptCluster cluster) throws SqlParseException { - List sqlNodes = SubstraitSqlStatementParser.parseStatements(sqlStatements); + final List sqlNodes = SubstraitSqlStatementParser.parseStatements(sqlStatements); return convert(sqlNodes, catalogReader, validator, cluster); } static List convert( - List sqlNodes, - Prepare.CatalogReader catalogReader, - SqlValidator validator, - RelOptCluster cluster) { - RelOptTable.ViewExpander viewExpander = null; - SqlToRelConverter converter = + final List sqlNodes, + final Prepare.CatalogReader catalogReader, + final SqlValidator validator, + final RelOptCluster cluster) { + final RelOptTable.ViewExpander viewExpander = null; + final SqlToRelConverter converter = new SqlToRelConverter( viewExpander, validator, @@ -128,17 +130,17 @@ static List convert( cluster, StandardConvertletTable.INSTANCE, SqlToRelConverter.CONFIG); - DdlSqlToRelConverter ddlSqlToRelConverter = new DdlSqlToRelConverter(converter); + final DdlSqlToRelConverter ddlSqlToRelConverter = new DdlSqlToRelConverter(converter); return sqlNodes.stream() .map(sqlNode -> sqlNode.accept(ddlSqlToRelConverter)) .collect(Collectors.toList()); } static RelOptCluster createDefaultRelOptCluster() { - RexBuilder rexBuilder = + final RexBuilder rexBuilder = new RexBuilder(new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM)); - HepProgram program = HepProgram.builder().build(); - RelOptPlanner emptyPlanner = new HepPlanner(program); + final HepProgram program = HepProgram.builder().build(); + final RelOptPlanner emptyPlanner = new HepPlanner(program); return RelOptCluster.create(emptyPlanner, rexBuilder); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java index 52be2d6a5..4d5099bc3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlValidator.java @@ -9,7 +9,7 @@ public class SubstraitSqlValidator extends SqlValidatorImpl { static SqlValidator.Config CONFIG = Config.DEFAULT.withIdentifierExpansion(true); - public SubstraitSqlValidator(Prepare.CatalogReader catalogReader) { + public SubstraitSqlValidator(final Prepare.CatalogReader catalogReader) { super(SubstraitOperatorTable.INSTANCE, catalogReader, catalogReader.getTypeFactory(), CONFIG); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java index 4f690e17f..b413521ee 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java @@ -22,25 +22,26 @@ class AggregationFunctionsTest extends PlanTestBase { static final TypeCreator N = TypeCreator.of(true); // Create a table with that has a column of every numeric type, both NOT NULL and NULL - private List numericTypesR = List.of(R.I8, R.I16, R.I32, R.I64, R.FP32, R.FP64); - private List numericTypesN = List.of(N.I8, N.I16, N.I32, N.I64, N.FP32, N.FP64); - private List numericTypes = + private final List numericTypesR = List.of(R.I8, R.I16, R.I32, R.I64, R.FP32, R.FP64); + private final List numericTypesN = List.of(N.I8, N.I16, N.I32, N.I64, N.FP32, N.FP64); + private final List numericTypes = Stream.concat(numericTypesR.stream(), numericTypesN.stream()).collect(Collectors.toList()); - private List tableTypes = + private final List tableTypes = Stream.concat( // Column to Group By Stream.of(N.I8), // Columns with Numeric Types numericTypes.stream()) .collect(Collectors.toList()); - private List columnNames = + private final List columnNames = Streams.mapWithIndex(tableTypes.stream(), (t, index) -> String.valueOf(index)) .collect(Collectors.toList()); - private NamedScan numericTypesTable = b.namedScan(List.of("example"), columnNames, tableTypes); + private final NamedScan numericTypesTable = + b.namedScan(List.of("example"), columnNames, tableTypes); // Create the given function call on the given field of the input - private Aggregate.Measure functionPicker(Rel input, int field, String fname) { + private Aggregate.Measure functionPicker(final Rel input, final int field, final String fname) { switch (fname) { case "min": return b.min(input, field); @@ -59,7 +60,7 @@ private Aggregate.Measure functionPicker(Rel input, int field, String fname) { } // Create one function call per numeric type column - private List functions(Rel input, String fname) { + private List functions(final Rel input, final String fname) { // first column is for grouping, skip it return IntStream.range(1, tableTypes.size()) .boxed() @@ -69,8 +70,8 @@ private List functions(Rel input, String fname) { @ParameterizedTest @ValueSource(strings = {"max", "min", "sum", "sum0", "avg"}) - void emptyGrouping(String aggFunction) { - Aggregate rel = + void emptyGrouping(final String aggFunction) { + final Aggregate rel = b.aggregate( input -> b.grouping(input), input -> functions(input, aggFunction), numericTypesTable); assertFullRoundTrip(rel); @@ -78,8 +79,8 @@ void emptyGrouping(String aggFunction) { @ParameterizedTest @ValueSource(strings = {"max", "min", "sum", "sum0", "avg"}) - void withGrouping(String aggFunction) { - Aggregate rel = + void withGrouping(final String aggFunction) { + final Aggregate rel = b.aggregate( input -> b.grouping(input, 0), input -> functions(input, aggFunction), diff --git a/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java index 9fbff8a96..255fafbcd 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ArithmeticFunctionTest.java @@ -11,8 +11,8 @@ class ArithmeticFunctionTest extends PlanTestBase { @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) - void arithmetic(String c) throws Exception { - String query = + void arithmetic(final String c) throws Exception { + final String query = String.format( "SELECT %s + %s, %s - %s, %s * %s, %s / %s FROM numbers", c, c, c, c, c, c, c, c); assertFullRoundTrip(query, CREATES); @@ -20,204 +20,204 @@ void arithmetic(String c) throws Exception { @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) - void abs(String column) throws Exception { - String query = String.format("SELECT abs(%s) FROM numbers", column); + void abs(final String column) throws Exception { + final String query = String.format("SELECT abs(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void exponential(String column) throws Exception { - String query = String.format("SELECT exp(%s) FROM numbers", column); + void exponential(final String column) throws Exception { + final String query = String.format("SELECT exp(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64"}) - void mod(String column) throws Exception { - String query = String.format("SELECT mod(%s, %s) FROM numbers", column, column); + void mod(final String column) throws Exception { + final String query = String.format("SELECT mod(%s, %s) FROM numbers", column, column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) - void negation(String column) throws Exception { - String query = String.format("SELECT -%s FROM numbers", column); + void negation(final String column) throws Exception { + final String query = String.format("SELECT -%s FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i64", "fp32", "fp64"}) - void power(String column) throws Exception { - String query = String.format("SELECT power(%s, %s) FROM numbers", column, column); + void power(final String column) throws Exception { + final String query = String.format("SELECT power(%s, %s) FROM numbers", column, column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"sin", "cos", "tan", "asin", "acos", "atan"}) - void trigonometric(String fname) throws Exception { - String query = String.format("SELECT %s(fp32), %s(fp64) FROM numbers", fname, fname); + void trigonometric(final String fname) throws Exception { + final String query = String.format("SELECT %s(fp32), %s(fp64) FROM numbers", fname, fname); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void atan2(String column) throws Exception { - String query = String.format("SELECT atan2(%s, %s) FROM numbers", column, column); + void atan2(final String column) throws Exception { + final String query = String.format("SELECT atan2(%s, %s) FROM numbers", column, column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) - void sign(String column) throws Exception { - String query = String.format("SELECT sign(%s) FROM numbers", column); + void sign(final String column) throws Exception { + final String query = String.format("SELECT sign(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) - void max(String column) throws Exception { - String query = String.format("SELECT max(%s) FROM numbers", column); + void max(final String column) throws Exception { + final String query = String.format("SELECT max(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) - void min(String column) throws Exception { - String query = String.format("SELECT min(%s) FROM numbers", column); + void min(final String column) throws Exception { + final String query = String.format("SELECT min(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) - void avg(String column) throws Exception { - String query = String.format("SELECT avg(%s) FROM numbers", column); + void avg(final String column) throws Exception { + final String query = String.format("SELECT avg(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) - void sum(String column) throws Exception { - String query = String.format("SELECT sum(%s) FROM numbers", column); + void sum(final String column) throws Exception { + final String query = String.format("SELECT sum(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) - void sum0(String column) throws Exception { - String query = String.format("SELECT sum0(%s) FROM numbers", column); + void sum0(final String column) throws Exception { + final String query = String.format("SELECT sum0(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i64", "fp32", "fp64"}) - void sqrt(String column) throws Exception { - String query = String.format("SELECT sqrt(%s) FROM numbers", column); + void sqrt(final String column) throws Exception { + final String query = String.format("SELECT sqrt(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void sinh(String column) throws Exception { - String query = String.format("SELECT SINH(%s) FROM numbers", column); + void sinh(final String column) throws Exception { + final String query = String.format("SELECT SINH(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void tanh(String column) throws Exception { - String query = String.format("SELECT TANH(%s) FROM numbers", column); + void tanh(final String column) throws Exception { + final String query = String.format("SELECT TANH(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void cosh(String column) throws Exception { - String query = String.format("SELECT COSH(%s) FROM numbers", column); + void cosh(final String column) throws Exception { + final String query = String.format("SELECT COSH(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void asinh(String column) throws Exception { - String query = String.format("SELECT ASINH(%s) FROM numbers", column); + void asinh(final String column) throws Exception { + final String query = String.format("SELECT ASINH(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void atanh(String column) throws Exception { - String query = String.format("SELECT ATANH(%s) FROM numbers", column); + void atanh(final String column) throws Exception { + final String query = String.format("SELECT ATANH(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void acosh(String column) throws Exception { - String query = String.format("SELECT ACOSH(%s) FROM numbers", column); + void acosh(final String column) throws Exception { + final String query = String.format("SELECT ACOSH(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64"}) - void bitwise_not_scalar(String column) throws Exception { - String query = String.format("SELECT BITNOT(%s) FROM numbers", column); + void bitwise_not_scalar(final String column) throws Exception { + final String query = String.format("SELECT BITNOT(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @CsvSource({"i8, 8", "i16, 160", "i32, 32000", "i64, CAST(6000000004 AS BIGINT)"}) - void bitwise_and_scalar(String column, String mask) throws Exception { - String query = String.format("SELECT BITAND(%s, %s) FROM numbers", column, mask); + void bitwise_and_scalar(final String column, final String mask) throws Exception { + final String query = String.format("SELECT BITAND(%s, %s) FROM numbers", column, mask); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @CsvSource({"i8, 8", "i16, 160", "i32, 32000", "i64, CAST(6000000004 AS BIGINT)"}) - void bitwise_xor_scalar(String column, String mask) throws Exception { - String query = String.format("SELECT BITXOR(%s, %s) FROM numbers", column, mask); + void bitwise_xor_scalar(final String column, final String mask) throws Exception { + final String query = String.format("SELECT BITXOR(%s, %s) FROM numbers", column, mask); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @CsvSource({"i8, 8", "i16, 160", "i32, 32000", "i64, CAST(6000000004 AS BIGINT)"}) - void bitwise_or_scalar(String column, String mask) throws Exception { - String query = String.format("SELECT BITOR(%s, %s) FROM numbers", column, mask); + void bitwise_or_scalar(final String column, final String mask) throws Exception { + final String query = String.format("SELECT BITOR(%s, %s) FROM numbers", column, mask); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void radians(String column) throws Exception { - String query = String.format("SELECT RADIANS(%s) FROM numbers", column); + void radians(final String column) throws Exception { + final String query = String.format("SELECT RADIANS(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void degrees(String column) throws Exception { - String query = String.format("SELECT DEGREES(%s) FROM numbers", column); + void degrees(final String column) throws Exception { + final String query = String.format("SELECT DEGREES(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i32", "i64"}) - void factorial(String column) throws Exception { - String query = String.format("SELECT FACTORIAL(%s) FROM numbers", column); + void factorial(final String column) throws Exception { + final String query = String.format("SELECT FACTORIAL(%s) FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64"}) - void bit_left_shift(String column) throws Exception { - String query = String.format("SELECT %s << 1 FROM numbers", column); + void bit_left_shift(final String column) throws Exception { + final String query = String.format("SELECT %s << 1 FROM numbers", column); assertFullRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64"}) - void leftshift(String column) throws Exception { - String query = String.format("SELECT LEFTSHIFT(%s, 1) FROM numbers", column); + void leftshift(final String column) throws Exception { + final String query = String.format("SELECT LEFTSHIFT(%s, 1) FROM numbers", column); assertFullRoundTrip(query, CREATES); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java index fe87d5fa1..ffb022001 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java @@ -97,23 +97,24 @@ void not() { test("not:bool", rex.makeCall(SqlStdOperatorTable.NOT, c(false, SqlTypeName.BOOLEAN))); } - private void test(String expectedName, RexNode call) { + private void test(final String expectedName, final RexNode call) { test(expectedName, call, c -> {}, true); } private void test( - String expectedName, - RexNode call, - Consumer consumer, - boolean bidirectional) { - Expression expression = call.accept(rexExpressionConverter); + final String expectedName, + final RexNode call, + final Consumer consumer, + final boolean bidirectional) { + final Expression expression = call.accept(rexExpressionConverter); assertTrue(expression instanceof Expression.ScalarFunctionInvocation); - Expression.ScalarFunctionInvocation func = (Expression.ScalarFunctionInvocation) expression; + final Expression.ScalarFunctionInvocation func = + (Expression.ScalarFunctionInvocation) expression; assertEquals(expectedName, func.declaration().key()); consumer.accept(func); if (bidirectional) { - RexNode convertedCall = expression.accept(expressionRexConverter, Context.newContext()); + final RexNode convertedCall = expression.accept(expressionRexConverter, Context.newContext()); assertEquals(call, convertedCall); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java index e21d3b653..a711d7e8c 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java @@ -93,7 +93,7 @@ void tStr() { @Test void tBinary() { - byte[] val = "my test".getBytes(StandardCharsets.UTF_8); + final byte[] val = "my test".getBytes(StandardCharsets.UTF_8); bitest( ExpressionCreator.binary(false, val), c(new org.apache.calcite.avatica.util.ByteString(val), SqlTypeName.VARBINARY)); @@ -108,9 +108,9 @@ void tTime() { @Test void tTimeWithMicroSecond() { - long microSec = (14L * 60 * 60 + 22 * 60 + 47) * 1000 * 1000 + 123456; - long seconds = TimeUnit.MICROSECONDS.toSeconds(microSec); - int fracSecondsInNano = + final long microSec = (14L * 60 * 60 + 22 * 60 + 47) * 1000 * 1000 + 123456; + final long seconds = TimeUnit.MICROSECONDS.toSeconds(microSec); + final int fracSecondsInNano = (int) (TimeUnit.MICROSECONDS.toNanos(microSec) - TimeUnit.SECONDS.toNanos(seconds)); assertEquals( TimeString.fromMillisOfDay((int) TimeUnit.SECONDS.toMillis(seconds)) @@ -138,17 +138,17 @@ void tDate() { @Test void tTimestamp() { - TimestampLiteral ts = ExpressionCreator.timestamp(false, 2002, 2, 14, 16, 20, 47, 123); - int nano = (int) TimeUnit.MICROSECONDS.toNanos(123); - TimestampString tsx = new TimestampString(2002, 2, 14, 16, 20, 47).withNanos(nano); + final TimestampLiteral ts = ExpressionCreator.timestamp(false, 2002, 2, 14, 16, 20, 47, 123); + final int nano = (int) TimeUnit.MICROSECONDS.toNanos(123); + final TimestampString tsx = new TimestampString(2002, 2, 14, 16, 20, 47).withNanos(nano); bitest(ts, rex.makeTimestampLiteral(tsx, 6)); } @Test void tTimestampWithMilliMacroSeconds() { - TimestampLiteral ts = ExpressionCreator.timestamp(false, 2002, 2, 14, 16, 20, 47, 123456); - int nano = (int) TimeUnit.MICROSECONDS.toNanos(123456); - TimestampString tsx = new TimestampString(2002, 2, 14, 16, 20, 47).withNanos(nano); + final TimestampLiteral ts = ExpressionCreator.timestamp(false, 2002, 2, 14, 16, 20, 47, 123456); + final int nano = (int) TimeUnit.MICROSECONDS.toNanos(123456); + final TimestampString tsx = new TimestampString(2002, 2, 14, 16, 20, 47).withNanos(nano); bitest(ts, rex.makeTimestampLiteral(tsx, 6)); } @@ -162,16 +162,16 @@ void tTimestampTZ() { @Test void tIntervalYearMonth() { - BigDecimal bd = new BigDecimal(3 * 12 + 5); // '3-5' year to month - RexLiteral intervalYearMonth = rex.makeIntervalLiteral(bd, YEAR_MONTH_INTERVAL); - IntervalYearLiteral intervalYearMonthExpr = ExpressionCreator.intervalYear(false, 3, 5); + final BigDecimal bd = new BigDecimal(3 * 12 + 5); // '3-5' year to month + final RexLiteral intervalYearMonth = rex.makeIntervalLiteral(bd, YEAR_MONTH_INTERVAL); + final IntervalYearLiteral intervalYearMonthExpr = ExpressionCreator.intervalYear(false, 3, 5); bitest(intervalYearMonthExpr, intervalYearMonth); } @Test void tIntervalYearMonthWithPrecision() { - BigDecimal bd = new BigDecimal(123 * 12 + 5); // '123-5' year to month - RexLiteral intervalYearMonth = + final BigDecimal bd = new BigDecimal(123 * 12 + 5); // '123-5' year to month + final RexLiteral intervalYearMonth = rex.makeIntervalLiteral( bd, new SqlIntervalQualifier( @@ -180,13 +180,13 @@ void tIntervalYearMonthWithPrecision() { org.apache.calcite.avatica.util.TimeUnit.MONTH, -1, SqlParserPos.QUOTED_ZERO)); - IntervalYearLiteral intervalYearMonthExpr = ExpressionCreator.intervalYear(false, 123, 5); + final IntervalYearLiteral intervalYearMonthExpr = ExpressionCreator.intervalYear(false, 123, 5); // rex --> expression assertEquals(intervalYearMonthExpr, intervalYearMonth.accept(rexExpressionConverter)); // expression -> rex - RexLiteral convertedRex = + final RexLiteral convertedRex = (RexLiteral) intervalYearMonthExpr.accept(expressionRexConverter, Context.newContext()); // Compare value only. Ignore the precision in SqlIntervalQualifier (which is used to parse @@ -199,14 +199,14 @@ void tIntervalYearMonthWithPrecision() { @Test void tIntervalMillisecond() { // Calcite stores milliseconds since Epoch, so test only millisecond precision - BigDecimal bd = + final BigDecimal bd = new BigDecimal( TimeUnit.DAYS.toMillis(3) + TimeUnit.HOURS.toMillis(5) + TimeUnit.MINUTES.toMillis(7) + TimeUnit.SECONDS.toMillis(9) + 500); // '3-5:7:9.500' day to second (6) - RexLiteral intervalDaySecond = + final RexLiteral intervalDaySecond = rex.makeIntervalLiteral( bd, new SqlIntervalQualifier( @@ -215,7 +215,7 @@ void tIntervalMillisecond() { org.apache.calcite.avatica.util.TimeUnit.SECOND, 3, SqlParserPos.ZERO)); - IntervalDayLiteral intervalDaySecondExpr = + final IntervalDayLiteral intervalDaySecondExpr = ExpressionCreator.intervalDay(false, 3, 5 * 3600 + 7 * 60 + 9, 500_000, 6); bitest(intervalDaySecondExpr, intervalDaySecond); } @@ -223,20 +223,20 @@ void tIntervalMillisecond() { @Test void tIntervalDay() { // Calcite always uses milliseconds - BigDecimal bd = new BigDecimal(TimeUnit.DAYS.toMillis(5)); - RexLiteral intervalDayLiteral = + final BigDecimal bd = new BigDecimal(TimeUnit.DAYS.toMillis(5)); + final RexLiteral intervalDayLiteral = rex.makeIntervalLiteral( bd, new SqlIntervalQualifier( org.apache.calcite.avatica.util.TimeUnit.DAY, -1, null, -1, SqlParserPos.ZERO)); - IntervalDayLiteral intervalDayExpr = ExpressionCreator.intervalDay(false, 5, 0, 0, 6); + final IntervalDayLiteral intervalDayExpr = ExpressionCreator.intervalDay(false, 5, 0, 0, 6); // rex --> expression - Expression convertedExpr = intervalDayLiteral.accept(rexExpressionConverter); + final Expression convertedExpr = intervalDayLiteral.accept(rexExpressionConverter); assertEquals(intervalDayExpr, convertedExpr); // expression -> rex - RexLiteral convertedRex = + final RexLiteral convertedRex = (RexLiteral) intervalDayExpr.accept(expressionRexConverter, Context.newContext()); // Compare value only. Ignore the precision in SqlIntervalQualifier in comparison. @@ -246,8 +246,8 @@ void tIntervalDay() { @Test void tIntervalYear() { - BigDecimal bd = new BigDecimal(123 * 12); // '123' year(3) - RexLiteral intervalYear = + final BigDecimal bd = new BigDecimal(123 * 12); // '123' year(3) + final RexLiteral intervalYear = rex.makeIntervalLiteral( bd, new SqlIntervalQualifier( @@ -256,12 +256,12 @@ void tIntervalYear() { null, -1, SqlParserPos.QUOTED_ZERO)); - IntervalYearLiteral intervalYearExpr = ExpressionCreator.intervalYear(false, 123, 0); + final IntervalYearLiteral intervalYearExpr = ExpressionCreator.intervalYear(false, 123, 0); // rex --> expression assertEquals(intervalYearExpr, intervalYear.accept(rexExpressionConverter)); // expression -> rex - RexLiteral convertedRex = + final RexLiteral convertedRex = (RexLiteral) intervalYearExpr.accept(expressionRexConverter, Context.newContext()); // Compare value only. Ignore the precision in SqlIntervalQualifier in comparison. @@ -272,8 +272,8 @@ void tIntervalYear() { @Test void tIntervalMonth() { - BigDecimal bd = new BigDecimal(123); // '123' month(3) - RexLiteral intervalMonth = + final BigDecimal bd = new BigDecimal(123); // '123' month(3) + final RexLiteral intervalMonth = rex.makeIntervalLiteral( bd, new SqlIntervalQualifier( @@ -282,13 +282,13 @@ void tIntervalMonth() { null, -1, SqlParserPos.QUOTED_ZERO)); - IntervalYearLiteral intervalMonthExpr = + final IntervalYearLiteral intervalMonthExpr = ExpressionCreator.intervalYear(false, 123 / 12, 123 % 12); // rex --> expression assertEquals(intervalMonthExpr, intervalMonth.accept(rexExpressionConverter)); // expression -> rex - RexLiteral convertedRex = + final RexLiteral convertedRex = (RexLiteral) intervalMonthExpr.accept(expressionRexConverter, Context.newContext()); // Compare value only. Ignore the precision in SqlIntervalQualifier in comparison. @@ -309,37 +309,37 @@ void tVarChar() { @Test void tDecimalLiteral() { - List decimalList = + final List decimalList = List.of( new BigDecimal("-123.457890"), new BigDecimal("123.457890"), new BigDecimal("123.450000"), new BigDecimal("-123.450000")); - for (BigDecimal bd : decimalList) { + for (final BigDecimal bd : decimalList) { bitest(ExpressionCreator.decimal(false, bd, 32, 6), c(bd, SqlTypeName.DECIMAL, 32, 6)); } } @Test void tDecimalLiteral2() { - List decimalList = + final List decimalList = List.of( new BigDecimal("-99.123456789123456789123456789123456789"), // scale = 36, precision =38 new BigDecimal("99.123456789123456789123456789123456789") // scale = 36, precision = 38 ); - for (BigDecimal bd : decimalList) { + for (final BigDecimal bd : decimalList) { bitest(ExpressionCreator.decimal(false, bd, 38, 36), c(bd, SqlTypeName.DECIMAL, 38, 36)); } } @Test void tDecimalUtil() { - long[] values = + final long[] values = new long[] {Long.MIN_VALUE, Integer.MIN_VALUE, 0, Integer.MAX_VALUE, Long.MAX_VALUE}; - for (long value : values) { - BigDecimal bd = BigDecimal.valueOf(value); - byte[] encoded = DecimalUtil.encodeDecimalIntoBytes(bd, 0, 16); - BigDecimal bd2 = DecimalUtil.getBigDecimalFromBytes(encoded, 0, 16); + for (final long value : values) { + final BigDecimal bd = BigDecimal.valueOf(value); + final byte[] encoded = DecimalUtil.encodeDecimalIntoBytes(bd, 0, 16); + final BigDecimal bd2 = DecimalUtil.getBigDecimalFromBytes(encoded, 0, 16); System.out.println(bd2); assertEquals(bd, bd2); } @@ -347,13 +347,13 @@ void tDecimalUtil() { @Test void tMap() { - ImmutableMap ss = + final ImmutableMap ss = ImmutableMap.of( ExpressionCreator.string(false, "foo"), ExpressionCreator.i32(false, 4), ExpressionCreator.string(false, "bar"), ExpressionCreator.i32(false, -1)); - RexNode calcite = + final RexNode calcite = rex.makeLiteral( ImmutableMap.of("foo", 4, "bar", -1), type.createMapType(t(SqlTypeName.VARCHAR), t(SqlTypeName.INTEGER)), @@ -387,20 +387,20 @@ void tStruct() { @Test void tFixedBinary() { - byte[] val = "my test".getBytes(StandardCharsets.UTF_8); + final byte[] val = "my test".getBytes(StandardCharsets.UTF_8); bitest( ExpressionCreator.fixedBinary(false, val), c(new org.apache.calcite.avatica.util.ByteString(val), SqlTypeName.BINARY)); } - public void test(Expression expression, RexNode rex) { + public void test(final Expression expression, final RexNode rex) { assertEquals(expression, rex.accept(new RexExpressionConverter())); } // bi-directional test : 1) rex -> substrait, substrait -> rex2. Compare rex == rex2 - public void bitest(Expression expression, RexNode rex) { + public void bitest(final Expression expression, final RexNode rex) { assertEquals(expression, rex.accept(rexExpressionConverter)); - RexNode convertedRex = expression.accept(expressionRexConverter, Context.newContext()); + final RexNode convertedRex = expression.accept(expressionRexConverter, Context.newContext()); assertEquals(rex, convertedRex); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteObjs.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteObjs.java index 5e2ce7e53..6ec747e7e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteObjs.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteObjs.java @@ -12,7 +12,7 @@ public abstract class CalciteObjs { final RelDataTypeFactory type = SubstraitTypeSystem.TYPE_FACTORY; final RexBuilder rex = new RexBuilder(type); - RelDataType t(SqlTypeName typeName, int... vals) { + RelDataType t(final SqlTypeName typeName, final int... vals) { switch (vals.length) { case 0: return type.createSqlType(typeName); @@ -25,20 +25,20 @@ RelDataType t(SqlTypeName typeName, int... vals) { } } - RelDataType tN(SqlTypeName typeName, int... vals) { + RelDataType tN(final SqlTypeName typeName, final int... vals) { return type.createTypeWithNullability(t(typeName, vals), true); } public RexNode makeCalciteLiteral( - boolean nullable, SqlTypeName typeName, Object value, int... vals) { + final boolean nullable, final SqlTypeName typeName, final Object value, final int... vals) { return rex.makeLiteral(value, nullable ? tN(typeName, vals) : t(typeName, vals), true, false); } - public RexNode c(Object value, SqlTypeName typeName, int... vals) { + public RexNode c(final Object value, final SqlTypeName typeName, final int... vals) { return makeCalciteLiteral(false, typeName, value, vals); } - public RexNode cN(Object value, SqlTypeName typeName, int... vals) { + public RexNode cN(final Object value, final SqlTypeName typeName, final int... vals) { return makeCalciteLiteral(true, typeName, value, vals); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java index 176552f7d..5bb071136 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java @@ -27,7 +27,7 @@ class CalciteTypeTest extends CalciteObjs { new UserTypeMapper() { @Nullable @Override - public Type toSubstrait(RelDataType relDataType) { + public Type toSubstrait(final RelDataType relDataType) { if (uTypeFactory.isTypeFromFactory(relDataType)) { return uTypeFactory.createSubstrait(relDataType.isNullable()); } @@ -36,7 +36,7 @@ public Type toSubstrait(RelDataType relDataType) { @Nullable @Override - public RelDataType toCalcite(Type.UserDefined type) { + public RelDataType toCalcite(final Type.UserDefined type) { if (type.urn().equals(uTypeURI) && type.name().equals(uTypeName)) { return uTypeFactory.createCalcite(type.nullable()); } @@ -46,49 +46,49 @@ public RelDataType toCalcite(Type.UserDefined type) { @ParameterizedTest @ValueSource(booleans = {true, false}) - void bool(boolean nullable) { + void bool(final boolean nullable) { testType(Type.withNullability(nullable).BOOLEAN, SqlTypeName.BOOLEAN, nullable); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void i8(boolean nullable) { + void i8(final boolean nullable) { testType(Type.withNullability(nullable).I8, SqlTypeName.TINYINT, nullable); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void i16(boolean nullable) { + void i16(final boolean nullable) { testType(Type.withNullability(nullable).I16, SqlTypeName.SMALLINT, nullable); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void i32(boolean nullable) { + void i32(final boolean nullable) { testType(Type.withNullability(nullable).I32, SqlTypeName.INTEGER, nullable); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void i64(boolean nullable) { + void i64(final boolean nullable) { testType(Type.withNullability(nullable).I64, SqlTypeName.BIGINT, nullable); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void fp32(boolean nullable) { + void fp32(final boolean nullable) { testType(Type.withNullability(nullable).FP32, SqlTypeName.REAL, nullable); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void fp64(boolean nullable) { + void fp64(final boolean nullable) { testType(Type.withNullability(nullable).FP64, SqlTypeName.DOUBLE, nullable); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void calciteFloatToFp64(boolean nullable) { + void calciteFloatToFp64(final boolean nullable) { assertEquals( Type.withNullability(nullable).FP64, TypeConverter.DEFAULT.toSubstrait( @@ -97,20 +97,20 @@ void calciteFloatToFp64(boolean nullable) { @ParameterizedTest @ValueSource(booleans = {true, false}) - void date(boolean nullable) { + void date(final boolean nullable) { testType(Type.withNullability(nullable).DATE, SqlTypeName.DATE, nullable); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void time(boolean nullable) { + void time(final boolean nullable) { testType(Type.withNullability(nullable).TIME, SqlTypeName.TIME, nullable, 6); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void precisionTimeStamp(boolean nullable) { - for (int precision : new int[] {0, 3, 6}) { + void precisionTimeStamp(final boolean nullable) { + for (final int precision : new int[] {0, 3, 6}) { testType( Type.withNullability(nullable).precisionTimestamp(precision), SqlTypeName.TIMESTAMP, @@ -121,8 +121,8 @@ void precisionTimeStamp(boolean nullable) { @ParameterizedTest @ValueSource(booleans = {true, false}) - void precisionTimestamptz(boolean nullable) { - for (int precision : new int[] {0, 3, 6}) { + void precisionTimestamptz(final boolean nullable) { + for (final int precision : new int[] {0, 3, 6}) { testType( Type.withNullability(nullable).precisionTimestampTZ(precision), SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE, @@ -133,7 +133,7 @@ void precisionTimestamptz(boolean nullable) { @ParameterizedTest @ValueSource(booleans = {true, false}) - void intervalYear(boolean nullable) { + void intervalYear(final boolean nullable) { testType( Type.withNullability(nullable).INTERVAL_YEAR, type.createSqlIntervalType(SubstraitTypeSystem.YEAR_MONTH_INTERVAL), @@ -142,7 +142,7 @@ void intervalYear(boolean nullable) { @ParameterizedTest @ValueSource(booleans = {true, false}) - void intervalDay(boolean nullable) { + void intervalDay(final boolean nullable) { testType( Type.withNullability(nullable).intervalDay(6), type.createSqlIntervalType(SubstraitTypeSystem.DAY_SECOND_INTERVAL), @@ -151,43 +151,43 @@ void intervalDay(boolean nullable) { @ParameterizedTest @ValueSource(booleans = {true, false}) - void string(boolean nullable) { + void string(final boolean nullable) { testType(Type.withNullability(nullable).STRING, SqlTypeName.VARCHAR, nullable); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void binary(boolean nullable) { + void binary(final boolean nullable) { testType(Type.withNullability(nullable).BINARY, SqlTypeName.VARBINARY, nullable); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void fixedBinary(boolean nullable) { + void fixedBinary(final boolean nullable) { testType(Type.withNullability(nullable).fixedBinary(74), SqlTypeName.BINARY, nullable, 74); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void fixedChar(boolean nullable) { + void fixedChar(final boolean nullable) { testType(Type.withNullability(nullable).fixedChar(74), SqlTypeName.CHAR, nullable, 74); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void varchar(boolean nullable) { + void varchar(final boolean nullable) { testType(Type.withNullability(nullable).varChar(74), SqlTypeName.VARCHAR, nullable, 74); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void decimal(boolean nullable) { + void decimal(final boolean nullable) { testType(Type.withNullability(nullable).decimal(38, 13), SqlTypeName.DECIMAL, nullable, 38, 13); } @ParameterizedTest @ValueSource(booleans = {true, false}) - void list(boolean nullable) { + void list(final boolean nullable) { testType( Type.withNullability(nullable).list(TypeCreator.REQUIRED.I16), type.createArrayType(type.createSqlType(SqlTypeName.SMALLINT), -1), @@ -196,7 +196,7 @@ void list(boolean nullable) { @ParameterizedTest @ValueSource(booleans = {true, false}) - void map(boolean nullable) { + void map(final boolean nullable) { testType( Type.withNullability(nullable).map(TypeCreator.REQUIRED.STRING, TypeCreator.REQUIRED.I8), type.createMapType( @@ -242,46 +242,57 @@ void nestedStruct() { @ParameterizedTest @ValueSource(booleans = {true, false}) - void userDefinedType(boolean nullable) { - Type type = uTypeFactory.createSubstrait(nullable); + void userDefinedType(final boolean nullable) { + final Type type = uTypeFactory.createSubstrait(nullable); testType(typeConverter, type, uTypeFactory.createCalcite(nullable), null); } - private void testType(TypeExpression expression, SqlTypeName typeName, boolean nullable) { + private void testType( + final TypeExpression expression, final SqlTypeName typeName, final boolean nullable) { testType(expression, type.createTypeWithNullability(type.createSqlType(typeName), nullable)); } private void testType( - TypeExpression expression, SqlTypeName typeName, boolean nullable, int prec) { + final TypeExpression expression, + final SqlTypeName typeName, + final boolean nullable, + final int prec) { testType( expression, type.createTypeWithNullability(type.createSqlType(typeName, prec), nullable)); } private void testType( - TypeExpression expression, SqlTypeName typeName, boolean nullable, int prec, int scale) { + final TypeExpression expression, + final SqlTypeName typeName, + final boolean nullable, + final int prec, + final int scale) { testType( expression, type.createTypeWithNullability(type.createSqlType(typeName, prec, scale), nullable)); } - private void testType(TypeExpression expression, RelDataType calciteType) { + private void testType(final TypeExpression expression, final RelDataType calciteType) { testType(expression, calciteType, null); } - private void testType(TypeExpression expression, RelDataType calciteType, boolean nullable) { + private void testType( + final TypeExpression expression, final RelDataType calciteType, final boolean nullable) { testType(expression, type.createTypeWithNullability(calciteType, nullable)); } private void testType( - TypeExpression expression, RelDataType calciteType, List dfsFieldNames) { + final TypeExpression expression, + final RelDataType calciteType, + final List dfsFieldNames) { testType(TypeConverter.DEFAULT, expression, calciteType, dfsFieldNames); } private void testType( - TypeConverter converter, - TypeExpression expression, - RelDataType calciteType, - List dfsFieldNames) { + final TypeConverter converter, + final TypeExpression expression, + final RelDataType calciteType, + final List dfsFieldNames) { assertEquals(expression, converter.toSubstrait(calciteType)); assertEquals(calciteType, converter.toCalcite(type, expression, dfsFieldNames)); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComparisonFunctionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComparisonFunctionsTest.java index d15764d72..4fa6f6bfd 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ComparisonFunctionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComparisonFunctionsTest.java @@ -11,39 +11,39 @@ class ComparisonFunctionsTest extends PlanTestBase { @Test void is_true() throws Exception { - String query = "SELECT ((int_a > int_b) IS TRUE) FROM numbers"; + final String query = "SELECT ((int_a > int_b) IS TRUE) FROM numbers"; assertSqlSubstraitRelRoundTrip(query, CREATES); } @Test void is_false() throws Exception { - String query = "SELECT ((int_a > int_b) IS FALSE) FROM numbers"; + final String query = "SELECT ((int_a > int_b) IS FALSE) FROM numbers"; assertSqlSubstraitRelRoundTrip(query, CREATES); } @Test void is_not_true() throws Exception { - String query = "SELECT ((int_a > int_b) IS NOT TRUE) FROM numbers"; + final String query = "SELECT ((int_a > int_b) IS NOT TRUE) FROM numbers"; assertSqlSubstraitRelRoundTrip(query, CREATES); } @Test void is_not_false() throws Exception { - String query = "SELECT ((int_a > int_b) IS NOT FALSE) FROM numbers"; + final String query = "SELECT ((int_a > int_b) IS NOT FALSE) FROM numbers"; assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @CsvSource({"int_a, int_b", "int_b, int_a", "double_a, double_b", "double_b, double_a"}) - void is_distinct_from(String left, String right) throws Exception { - String query = String.format("SELECT (%s IS DISTINCT FROM %s) FROM numbers", left, right); + void is_distinct_from(final String left, final String right) throws Exception { + final String query = String.format("SELECT (%s IS DISTINCT FROM %s) FROM numbers", left, right); assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"int_a", "int_b", "double_a", "double_b"}) - void is_distinct_from_null_vs_col(String column) throws Exception { - String query = String.format("SELECT (NULL IS DISTINCT FROM %s) FROM numbers", column); + void is_distinct_from_null_vs_col(final String column) throws Exception { + final String query = String.format("SELECT (NULL IS DISTINCT FROM %s) FROM numbers", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } @@ -55,9 +55,9 @@ void is_distinct_from_null_vs_col(String column) throws Exception { "int_a, int_b, double_a", "CAST(NULL AS INT), int_a, int_b" }) - void least(String args) throws Exception { - String join_args = String.join(", ", args); - String query = String.format("SELECT LEAST(%s) FROM numbers", join_args); + void least(final String args) throws Exception { + final String join_args = String.join(", ", args); + final String query = String.format("SELECT LEAST(%s) FROM numbers", join_args); assertSqlSubstraitRelRoundTrip(query, CREATES); } @@ -69,9 +69,9 @@ void least(String args) throws Exception { "int_a, int_b, double_a", "CAST(NULL AS INT), int_a, int_b" }) - void greatest(String args) throws Exception { - String join_args = String.join(", ", args); - String query = String.format("SELECT LEAST(%s) FROM numbers", join_args); + void greatest(final String args) throws Exception { + final String join_args = String.join(", ", args); + final String query = String.format("SELECT LEAST(%s) FROM numbers", join_args); assertSqlSubstraitRelRoundTrip(query, CREATES); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java index dc3cfdbbc..529d85482 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java @@ -27,12 +27,14 @@ class ComplexAggregateTest extends PlanTestBase { private Aggregate.Grouping emptyGrouping = Aggregate.Grouping.builder().build(); - Aggregate.Measure withPreMeasureFilter(Aggregate.Measure measure, Expression preMeasureFilter) { + Aggregate.Measure withPreMeasureFilter( + final Aggregate.Measure measure, final Expression preMeasureFilter) { return Aggregate.Measure.builder().from(measure).preMeasureFilter(preMeasureFilter).build(); } - Aggregate.Measure withSort(Aggregate.Measure measure, List sortFields) { - ImmutableAggregateFunctionInvocation afi = + Aggregate.Measure withSort( + final Aggregate.Measure measure, final List sortFields) { + final ImmutableAggregateFunctionInvocation afi = AggregateFunctionInvocation.builder().from(measure.getFunction()).sort(sortFields).build(); return Aggregate.Measure.builder().from(measure).function(afi).build(); } @@ -49,8 +51,9 @@ Aggregate.Measure withSort(Aggregate.Measure measure, List * @param pojo a pojo that requires transformation for use in Calcite * @param expectedTransform the expected transformation output */ - protected void validateAggregateTransformation(Aggregate pojo, Rel expectedTransform) { - Aggregate converterPojo = + protected void validateAggregateTransformation( + final Aggregate pojo, final Rel expectedTransform) { + final Aggregate converterPojo = PreCalciteAggregateValidator.PreCalciteAggregateTransformer .transformToValidCalciteAggregate(pojo); assertEquals(expectedTransform, converterPojo); @@ -62,13 +65,13 @@ protected void validateAggregateTransformation(Aggregate pojo, Rel expectedTrans @Test void handleComplexMeasureArgument() { // SELECT sum(c + 7) FROM example - Aggregate rel = + final Aggregate rel = b.aggregate( input -> emptyGrouping, input -> List.of(b.sum(b.add(b.fieldReference(input, 2), b.i32(7)))), table); - Aggregate expectedFinal = + final Aggregate expectedFinal = b.aggregate( input -> emptyGrouping, // sum call references input field @@ -84,7 +87,7 @@ void handleComplexMeasureArgument() { @Test void handleComplexPreMeasureFilter() { // SELECT sum(a) FILTER (b = 42) FROM example - Aggregate rel = + final Aggregate rel = b.aggregate( input -> emptyGrouping, input -> @@ -93,7 +96,7 @@ void handleComplexPreMeasureFilter() { b.sum(input, 0), b.equal(b.fieldReference(input, 1), b.i32(42)))), table); - Aggregate expectedFinal = + final Aggregate expectedFinal = b.aggregate( input -> emptyGrouping, input -> List.of(withPreMeasureFilter(b.sum(input, 0), b.fieldReference(input, 4))), @@ -105,7 +108,7 @@ void handleComplexPreMeasureFilter() { @Test void handleComplexSortingArguments() { // SELECT sum(d ORDER BY -b ASC) FROM example - Aggregate rel = + final Aggregate rel = b.aggregate( input -> emptyGrouping, input -> @@ -118,7 +121,7 @@ void handleComplexSortingArguments() { Expression.SortDirection.ASC_NULLS_FIRST)))), table); - Aggregate expectedFinal = + final Aggregate expectedFinal = b.aggregate( input -> emptyGrouping, input -> @@ -139,7 +142,7 @@ void handleComplexSortingArguments() { @Test void handleComplexGroupingArgument() { - Aggregate rel = + final Aggregate rel = b.aggregate( input -> b.grouping( @@ -147,7 +150,7 @@ void handleComplexGroupingArgument() { input -> List.of(), table); - Aggregate expectedFinal = + final Aggregate expectedFinal = b.aggregate( // grouping exprs are now field references to input input -> b.grouping(input, 4, 5), @@ -163,9 +166,10 @@ void handleComplexGroupingArgument() { @Test void handleOutOfOrderGroupingArguments() { - Aggregate rel = b.aggregate(input -> b.grouping(input, 1, 0, 2), input -> List.of(), table); + final Aggregate rel = + b.aggregate(input -> b.grouping(input, 1, 0, 2), input -> List.of(), table); - Aggregate expectedFinal = + final Aggregate expectedFinal = b.aggregate( // grouping exprs are now field references to input input -> b.grouping(input, 4, 5, 6), @@ -184,12 +188,12 @@ void handleOutOfOrderGroupingArguments() { @Test void outOfOrderGroupingKeysHaveCorrectCalciteType() { - Rel rel = + final Rel rel = b.aggregate( input -> b.grouping(input, 2, 0), input -> List.of(), b.namedScan(List.of("foo"), List.of("a", "b", "c"), List.of(R.I64, R.I64, R.STRING))); - RelNode relNode = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(rel); + final RelNode relNode = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(rel); assertRowMatch(relNode.getRowType(), R.STRING, R.I64); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java index afe088d30..c74643c44 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java @@ -31,15 +31,15 @@ class ComplexSortTest extends PlanTestBase { * information. A {@link RelNode} is only annotated if its {@link RelCollation} is not empty. */ public static class CollationRelWriter extends RelWriterImpl { - public CollationRelWriter(StringWriter sw) { + public CollationRelWriter(final StringWriter sw) { super(new PrintWriter(sw), SqlExplainLevel.EXPPLAN_ATTRIBUTES, false); } @Override - protected void explain_(RelNode rel, List> values) { - RelCollation collation = rel.getTraitSet().getCollation(); + protected void explain_(final RelNode rel, final List> values) { + final RelCollation collation = rel.getTraitSet().getCollation(); if (!collation.isDefault()) { - StringBuilder s = new StringBuilder(); + final StringBuilder s = new StringBuilder(); spacer.spaces(s); s.append("Collation: ").append(collation.toString()); pw.println(s); @@ -53,7 +53,7 @@ void handleInputReferenceSort() { // CREATE TABLE example (a VARCHAR) // SELECT a FROM example ORDER BY a - Rel rel = + final Rel rel = b.project( input -> b.fieldReferences(input, 0), b.remap(1), @@ -64,13 +64,13 @@ void handleInputReferenceSort() { b.fieldReference(input, 0), Expression.SortDirection.ASC_NULLS_LAST)), b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); - String expected = + final String expected = "Collation: [0]\n" + "LogicalSort(sort0=[$0], dir0=[ASC])\n" + " LogicalTableScan(table=[[example]])\n"; - RelNode relReturned = substraitToCalcite.convert(rel); - StringWriter sw = new StringWriter(); + final RelNode relReturned = substraitToCalcite.convert(rel); + final StringWriter sw = new StringWriter(); relReturned.explain(new CollationRelWriter(sw)); assertEquals(expected, sw.toString()); } @@ -80,7 +80,7 @@ void handleCastExpressionSort() { // CREATE TABLE example (a VARCHAR) // SELECT a FROM example ORDER BY a::INT - Rel rel = + final Rel rel = b.project( input -> b.fieldReferences(input, 0), b.remap(1), @@ -92,15 +92,15 @@ void handleCastExpressionSort() { Expression.SortDirection.ASC_NULLS_LAST)), b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); - String expected = + final String expected = "LogicalProject(a0=[$0])\n" + " Collation: [1]\n" + " LogicalSort(sort0=[$1], dir0=[ASC])\n" + " LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL])\n" + " LogicalTableScan(table=[[example]])\n"; - RelNode relReturned = substraitToCalcite.convert(rel); - StringWriter sw = new StringWriter(); + final RelNode relReturned = substraitToCalcite.convert(rel); + final StringWriter sw = new StringWriter(); relReturned.explain(new CollationRelWriter(sw)); assertEquals(expected, sw.toString()); } @@ -110,7 +110,7 @@ void handleCastProjectAndSortWithSortDirection() { // CREATE TABLE example (a VARCHAR) // SELECT a::INT FROM example ORDER BY a::INT DESC NULLS LAST - Rel rel = + final Rel rel = b.project( input -> List.of(b.cast(b.fieldReference(input, 0), R.I32)), b.remap(1), @@ -122,15 +122,15 @@ void handleCastProjectAndSortWithSortDirection() { Expression.SortDirection.DESC_NULLS_LAST)), b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); - String expected = + final String expected = "LogicalProject(a0=[CAST($0):INTEGER NOT NULL])\n" + " Collation: [1 DESC-nulls-last]\n" + " LogicalSort(sort0=[$1], dir0=[DESC-nulls-last])\n" + " LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL])\n" + " LogicalTableScan(table=[[example]])\n"; - RelNode relReturned = substraitToCalcite.convert(rel); - StringWriter sw = new StringWriter(); + final RelNode relReturned = substraitToCalcite.convert(rel); + final StringWriter sw = new StringWriter(); relReturned.explain(new CollationRelWriter(sw)); assertEquals(expected, sw.toString()); } @@ -140,7 +140,7 @@ void handleCastSortToOriginalType() { // CREATE TABLE example (a VARCHAR) // SELECT a FROM example ORDER BY a::VARCHAR - Rel rel = + final Rel rel = b.project( input -> List.of(b.fieldReference(input, 0)), b.remap(1), @@ -152,15 +152,15 @@ void handleCastSortToOriginalType() { Expression.SortDirection.DESC_NULLS_LAST)), b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING)))); - String expected = + final String expected = "LogicalProject(a0=[$0])\n" + " Collation: [1 DESC-nulls-last]\n" + " LogicalSort(sort0=[$1], dir0=[DESC-nulls-last])\n" + " LogicalProject(a=[$0], a0=[$0])\n" + " LogicalTableScan(table=[[example]])\n"; - RelNode relReturned = substraitToCalcite.convert(rel); - StringWriter sw = new StringWriter(); + final RelNode relReturned = substraitToCalcite.convert(rel); + final StringWriter sw = new StringWriter(); relReturned.explain(new CollationRelWriter(sw)); assertEquals(expected, sw.toString()); } @@ -170,7 +170,7 @@ void handleComplex2ExpressionSort() { // CREATE TABLE example (a VARCHAR, b INT) // SELECT b, a FROM example ORDER BY a::INT DESC, -b + 42 ASC NULLS LAST - Rel rel = + final Rel rel = b.project( input -> List.of(b.fieldReference(input, 0), b.fieldReference(input, 1)), b.remap(2, 3), @@ -185,15 +185,15 @@ void handleComplex2ExpressionSort() { Expression.SortDirection.ASC_NULLS_LAST)), b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.STRING, R.I32)))); - String expected = + final String expected = "LogicalProject(a0=[$0], b0=[$1])\n" + " Collation: [2 DESC, 3]\n" + " LogicalSort(sort0=[$2], sort1=[$3], dir0=[DESC], dir1=[ASC])\n" + " LogicalProject(a=[$0], b=[$1], a0=[CAST($0):INTEGER NOT NULL], $f3=[+(-($1), 42)])\n" + " LogicalTableScan(table=[[example]])\n"; - RelNode relReturned = substraitToCalcite.convert(rel); - StringWriter sw = new StringWriter(); + final RelNode relReturned = substraitToCalcite.convert(rel); + final StringWriter sw = new StringWriter(); relReturned.explain(new CollationRelWriter(sw)); assertEquals(expected, sw.toString()); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index 34e06d0ac..94906362b 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -49,7 +49,7 @@ class CustomFunctionTest extends PlanTestBase { static { try { FUNCTIONS_CUSTOM = asString("extensions/functions_custom.yaml"); - } catch (IOException e) { + } catch (final IOException e) { throw new UncheckedIOException(e); } } @@ -71,7 +71,7 @@ class CustomFunctionTest extends PlanTestBase { new UserTypeMapper() { @Nullable @Override - public Type toSubstrait(RelDataType relDataType) { + public Type toSubstrait(final RelDataType relDataType) { if (aTypeFactory.isTypeFromFactory(relDataType)) { return TypeCreator.of(relDataType.isNullable()).userDefined(URN, aTypeName); } @@ -83,7 +83,7 @@ public Type toSubstrait(RelDataType relDataType) { @Nullable @Override - public RelDataType toCalcite(Type.UserDefined type) { + public RelDataType toCalcite(final Type.UserDefined type) { if (type.urn().equals(URN)) { if (type.name().equals(aTypeName)) { return aTypeFactory.createCalcite(type.nullable()); @@ -269,14 +269,15 @@ public RelDataType toCalcite(Type.UserDefined type) { class CustomSubstraitToCalcite extends SubstraitToCalcite { public CustomSubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter) { + final SimpleExtension.ExtensionCollection extensions, + final RelDataTypeFactory typeFactory, + final TypeConverter typeConverter) { super(extensions, typeFactory, typeConverter); } @Override - protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { + protected SubstraitRelNodeConverter createSubstraitRelNodeConverter( + final RelBuilder relBuilder) { return new SubstraitRelNodeConverter( typeFactory, relBuilder, @@ -291,21 +292,21 @@ protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder r void customScalarFunctionRoundtrip() { // CREATE TABLE example(a TEXT) // SELECT custom_scalar(a) FROM example - Rel rel = + final Rel rel = b.project( input -> List.of(b.scalarFn(URN, "custom_scalar:str", R.STRING, b.fieldReference(input, 0))), b.remap(1), b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @Test void customScalarAnyFunctionRoundtrip() { - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -313,14 +314,14 @@ void customScalarAnyFunctionRoundtrip() { b.remap(1), b.namedScan(List.of("example"), List.of("a"), List.of(R.I64))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @Test void customScalarAnyToAnyFunctionRoundtrip() { - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -329,14 +330,14 @@ void customScalarAnyToAnyFunctionRoundtrip() { b.remap(1), b.namedScan(List.of("example"), List.of("a"), List.of(R.FP64))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @Test void customScalarAny1Any1ToAny1FunctionRoundtrip() { - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -349,14 +350,14 @@ void customScalarAny1Any1ToAny1FunctionRoundtrip() { b.remap(2), b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.FP64))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @Test void customScalarAny1Any1ToAny1FunctionMismatch() { - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -372,7 +373,7 @@ void customScalarAny1Any1ToAny1FunctionMismatch() { assertThrows( IllegalArgumentException.class, () -> { - RelNode calciteRel = substraitToCalcite.convert(rel); + final RelNode calciteRel = substraitToCalcite.convert(rel); calciteToSubstrait.apply(calciteRel); }, "Unable to convert call custom_scalar_any1any1_to_any1(fp64, string)"); @@ -380,7 +381,7 @@ void customScalarAny1Any1ToAny1FunctionMismatch() { @Test void customScalarAny1Any2ToAny2FunctionRoundtrip() { - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -393,14 +394,14 @@ void customScalarAny1Any2ToAny2FunctionRoundtrip() { b.remap(2), b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.FP64, R.STRING))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @Test void customScalarListAnyRoundtrip() { - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -412,14 +413,14 @@ void customScalarListAnyRoundtrip() { b.remap(1), b.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.I64)))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @Test void customScalarListAnyAndAnyRoundtrip() { - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -433,14 +434,14 @@ void customScalarListAnyAndAnyRoundtrip() { b.namedScan( List.of("example"), List.of("a", "b"), List.of(R.list(R.STRING), R.STRING))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @Test void customScalarListStringRoundtrip() { - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -452,14 +453,14 @@ void customScalarListStringRoundtrip() { b.remap(1), b.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.STRING)))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @Test void customScalarListStringAndAnyRoundtrip() { - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -473,14 +474,14 @@ void customScalarListStringAndAnyRoundtrip() { b.namedScan( List.of("example"), List.of("a", "b"), List.of(R.list(R.STRING), R.STRING))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @Test void customScalarListStringAndAnyVariadic0Roundtrip() { - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -498,14 +499,14 @@ void customScalarListStringAndAnyVariadic0Roundtrip() { List.of("a", "b", "c", "d"), List.of(R.list(R.STRING), R.STRING, R.STRING, R.STRING))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @Test void customScalarListStringAndAnyVariadic0NoArgsRoundtrip() { - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -517,14 +518,14 @@ void customScalarListStringAndAnyVariadic0NoArgsRoundtrip() { b.remap(1), b.namedScan(List.of("example"), List.of("a"), List.of(R.list(R.STRING)))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @Test void customScalarListStringAndAnyVariadic1Roundtrip() { - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -538,8 +539,8 @@ void customScalarListStringAndAnyVariadic1Roundtrip() { b.namedScan( List.of("example"), List.of("a", "b"), List.of(R.list(R.STRING), R.STRING))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @@ -547,7 +548,7 @@ void customScalarListStringAndAnyVariadic1Roundtrip() { void customAggregateFunctionRoundtrip() { // CREATE TABLE example (a BIGINT) // SELECT custom_aggregate(a) FROM example GROUP BY a - Rel rel = + final Rel rel = b.aggregate( input -> b.grouping(input, 0), input -> @@ -557,8 +558,8 @@ void customAggregateFunctionRoundtrip() { URN, "custom_aggregate:i64", R.I64, b.fieldReference(input, 0)))), b.namedScan(List.of("example"), List.of("a"), List.of(R.I64))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @@ -566,7 +567,7 @@ void customAggregateFunctionRoundtrip() { void customTypesInFunctionsRoundtrip() { // CREATE TABLE example(a a_type) // SELECT to_b_type(a) FROM example - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -578,31 +579,32 @@ void customTypesInFunctionsRoundtrip() { b.remap(1), b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - RelNode calciteRel = substraitToCalcite.convert(rel); - Rel relReturned = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel); + final Rel relReturned = calciteToSubstrait.apply(calciteRel); assertEquals(rel, relReturned); } @Test void customTypesLiteralInFunctionsRoundtrip() { - Builder bldr = Expression.Literal.newBuilder(); - Any anyValue = Any.pack(bldr.setI32(10).build()); - UserDefinedLiteral val = ExpressionCreator.userDefinedLiteral(false, URN, "a_type", anyValue); + final Builder bldr = Expression.Literal.newBuilder(); + final Any anyValue = Any.pack(bldr.setI32(10).build()); + final UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteral(false, URN, "a_type", anyValue); - Rel rel1 = + final Rel rel1 = b.project( input -> List.of(b.scalarFn(URN, "to_b_type:u!a_type", R.userDefined(URN, "b_type"), val)), b.remap(1), b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - RelNode calciteRel = substraitToCalcite.convert(rel1); - Rel rel2 = calciteToSubstrait.apply(calciteRel); + final RelNode calciteRel = substraitToCalcite.convert(rel1); + final Rel rel2 = calciteToSubstrait.apply(calciteRel); assertEquals(rel1, rel2); - ExtensionCollector extensionCollector = new ExtensionCollector(); - io.substrait.proto.Rel protoRel = new RelProtoConverter(extensionCollector).toProto(rel1); - Rel rel3 = new ProtoRelConverter(extensionCollector, extensionCollection).from(protoRel); + final ExtensionCollector extensionCollector = new ExtensionCollector(); + final io.substrait.proto.Rel protoRel = new RelProtoConverter(extensionCollector).toProto(rel1); + final Rel rel3 = new ProtoRelConverter(extensionCollector, extensionCollection).from(protoRel); assertEquals(rel1, rel3); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/DdlRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/DdlRoundtripTest.java index 1f14d9459..886809f11 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/DdlRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/DdlRoundtripTest.java @@ -17,13 +17,13 @@ public DdlRoundtripTest() throws SqlParseException { @Test void testCreateTable() throws Exception { - String sql = "create table dst1 as select * from src1"; + final String sql = "create table dst1 as select * from src1"; assertFullRoundTripWithIdentityProjectionWorkaround(sql, catalogReader); } @Test void testCreateView() throws Exception { - String sql = "create view dst1 as select * from src1"; + final String sql = "create view dst1 as select * from src1"; assertFullRoundTripWithIdentityProjectionWorkaround(sql, catalogReader); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/DuplicateFunctionUrnTest.java b/isthmus/src/test/java/io/substrait/isthmus/DuplicateFunctionUrnTest.java index 1d8cca4cc..8ac4d3ebf 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/DuplicateFunctionUrnTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/DuplicateFunctionUrnTest.java @@ -28,8 +28,8 @@ class DuplicateFunctionUrnTest extends PlanTestBase { static { try { - String extensions1 = asString("extensions/functions_duplicate_urn1.yaml"); - String extensions2 = asString("extensions/functions_duplicate_urn2.yaml"); + final String extensions1 = asString("extensions/functions_duplicate_urn1.yaml"); + final String extensions2 = asString("extensions/functions_duplicate_urn2.yaml"); collection1 = SimpleExtension.load("urn:extension:io.substrait:functions_string", extensions1); collection2 = SimpleExtension.load("urn:extension:com.domain:string", extensions2); @@ -37,7 +37,7 @@ class DuplicateFunctionUrnTest extends PlanTestBase { // Verify that the merged collection contains duplicate concat functions with different URNs // This is a precondition for the tests - if this fails, the tests don't make sense - List concatFunctions = + final List concatFunctions = collection.scalarFunctions().stream().filter(f -> f.name().equals("concat")).toList(); if (concatFunctions.size() != 2) { @@ -46,13 +46,13 @@ class DuplicateFunctionUrnTest extends PlanTestBase { + concatFunctions.size()); } - String urn1 = concatFunctions.get(0).getAnchor().urn(); - String urn2 = concatFunctions.get(1).getAnchor().urn(); + final String urn1 = concatFunctions.get(0).getAnchor().urn(); + final String urn2 = concatFunctions.get(1).getAnchor().urn(); if (urn1.equals(urn2)) { throw new IllegalStateException( "Expected different URNs for the two concat functions, but both were: " + urn1); } - } catch (IOException e) { + } catch (final IOException e) { throw new UncheckedIOException(e); } } @@ -82,14 +82,14 @@ void testMergeOrderDeterminesFunctionPrecedence() { // The FunctionConverter uses a "last-wins" strategy: the last function added to the // extension collection will be matched when converting from Calcite to Substrait. - SimpleExtension.ExtensionCollection reverseCollection = collection2.merge(collection1); - ScalarFunctionConverter converterA = + final SimpleExtension.ExtensionCollection reverseCollection = collection2.merge(collection1); + final ScalarFunctionConverter converterA = new ScalarFunctionConverter(collection.scalarFunctions(), typeFactory); - ScalarFunctionConverter converterB = + final ScalarFunctionConverter converterB = new ScalarFunctionConverter(reverseCollection.scalarFunctions(), typeFactory); - RexBuilder rexBuilder = new RexBuilder(typeFactory); - RexCall concatCall = + final RexBuilder rexBuilder = new RexBuilder(typeFactory); + final RexCall concatCall = (RexCall) rexBuilder.makeCall( SqlStdOperatorTable.CONCAT, @@ -97,20 +97,22 @@ void testMergeOrderDeterminesFunctionPrecedence() { rexBuilder.makeLiteral("world")); // Create a simple topLevelConverter that converts literals to Substrait expressions - java.util.function.Function topLevelConverter = + final java.util.function.Function topLevelConverter = rexNode -> { - org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode; + final org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode; return Expression.StrLiteral.builder() .value(lit.getValueAs(String.class)) .nullable(false) .build(); }; - Optional exprA = converterA.convert(concatCall, topLevelConverter); - Optional exprB = converterB.convert(concatCall, topLevelConverter); + final Optional exprA = converterA.convert(concatCall, topLevelConverter); + final Optional exprB = converterB.convert(concatCall, topLevelConverter); - Expression.ScalarFunctionInvocation funcA = (Expression.ScalarFunctionInvocation) exprA.get(); - Expression.ScalarFunctionInvocation funcB = (Expression.ScalarFunctionInvocation) exprB.get(); + final Expression.ScalarFunctionInvocation funcA = + (Expression.ScalarFunctionInvocation) exprA.get(); + final Expression.ScalarFunctionInvocation funcB = + (Expression.ScalarFunctionInvocation) exprB.get(); assertEquals( "extension:com.domain:string", @@ -131,19 +133,19 @@ void testLtrimMergeOrderWithDefaultExtensions() { // The FunctionConverter uses a "last-wins" strategy. // Merge default extensions with collection2 - collection2's ltrim should be last - SimpleExtension.ExtensionCollection defaultWithCustom = extensions.merge(collection2); + final SimpleExtension.ExtensionCollection defaultWithCustom = extensions.merge(collection2); // Merge collection2 with default extensions - default ltrim should be last - SimpleExtension.ExtensionCollection customWithDefault = collection2.merge(extensions); + final SimpleExtension.ExtensionCollection customWithDefault = collection2.merge(extensions); - ScalarFunctionConverter converterA = + final ScalarFunctionConverter converterA = new ScalarFunctionConverter(defaultWithCustom.scalarFunctions(), typeFactory); - ScalarFunctionConverter converterB = + final ScalarFunctionConverter converterB = new ScalarFunctionConverter(customWithDefault.scalarFunctions(), typeFactory); // Create a TRIM(LEADING ' ' FROM 'test') call which uses TrimFunctionMapper to map to ltrim - RexBuilder rexBuilder = new RexBuilder(typeFactory); - RexCall trimCall = + final RexBuilder rexBuilder = new RexBuilder(typeFactory); + final RexCall trimCall = (RexCall) rexBuilder.makeCall( SqlStdOperatorTable.TRIM, @@ -151,10 +153,10 @@ void testLtrimMergeOrderWithDefaultExtensions() { rexBuilder.makeLiteral(" "), rexBuilder.makeLiteral("test")); - java.util.function.Function topLevelConverter = + final java.util.function.Function topLevelConverter = rexNode -> { - org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode; - Object value = lit.getValue(); + final org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode; + final Object value = lit.getValue(); if (value == null) { return Expression.StrLiteral.builder().value("").nullable(true).build(); } @@ -162,17 +164,19 @@ void testLtrimMergeOrderWithDefaultExtensions() { return Expression.StrLiteral.builder().value(value.toString()).nullable(false).build(); }; - Optional exprA = converterA.convert(trimCall, topLevelConverter); - Optional exprB = converterB.convert(trimCall, topLevelConverter); + final Optional exprA = converterA.convert(trimCall, topLevelConverter); + final Optional exprB = converterB.convert(trimCall, topLevelConverter); - Expression.ScalarFunctionInvocation funcA = (Expression.ScalarFunctionInvocation) exprA.get(); + final Expression.ScalarFunctionInvocation funcA = + (Expression.ScalarFunctionInvocation) exprA.get(); // converterA should use collection2's custom ltrim (last) assertEquals( "extension:com.domain:string", funcA.declaration().getAnchor().urn(), "converterA should use last ltrim (custom from collection2)"); - Expression.ScalarFunctionInvocation funcB = (Expression.ScalarFunctionInvocation) exprB.get(); + final Expression.ScalarFunctionInvocation funcB = + (Expression.ScalarFunctionInvocation) exprB.get(); // converterB should use default extensions' ltrim (last) assertEquals( "extension:io.substrait:functions_string", diff --git a/isthmus/src/test/java/io/substrait/isthmus/EmptyArrayLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/EmptyArrayLiteralTest.java index 20237e0f4..c6760267a 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/EmptyArrayLiteralTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/EmptyArrayLiteralTest.java @@ -17,9 +17,9 @@ class EmptyArrayLiteralTest extends PlanTestBase { @Test void emptyArrayLiteral() { - Type colType = N.I8; - EmptyListLiteral emptyListLiteral = ExpressionCreator.emptyList(false, N.I8); - Project rel = + final Type colType = N.I8; + final EmptyListLiteral emptyListLiteral = ExpressionCreator.emptyList(false, N.I8); + final Project rel = b.project( input -> List.of(emptyListLiteral), Rel.Remap.offset(1, 1), @@ -29,9 +29,9 @@ void emptyArrayLiteral() { @Test void nullableEmptyArrayLiteral() { - Type colType = N.I8; - EmptyListLiteral emptyListLiteral = ExpressionCreator.emptyList(true, N.I8); - Project rel = + final Type colType = N.I8; + final EmptyListLiteral emptyListLiteral = ExpressionCreator.emptyList(true, N.I8); + final Project rel = b.project( input -> List.of(emptyListLiteral), Rel.Remap.offset(1, 1), diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java index db04ffc28..de6ea00a8 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java @@ -72,9 +72,10 @@ void inPredicate() throws IOException, SqlParseException { @Test void singleOrList() { - Expression singleOrList = b.singleOrList(b.fieldReference(commonTable, 0), b.i32(5), b.i32(10)); - RexNode rexNode = singleOrList.accept(converter, Context.newContext()); - Expression substraitExpression = + final Expression singleOrList = + b.singleOrList(b.fieldReference(commonTable, 0), b.i32(5), b.i32(10)); + final RexNode rexNode = singleOrList.accept(converter, Context.newContext()); + final Expression substraitExpression = rexNode.accept( new RexExpressionConverter( CREATE_SEARCH_CONV.apply(rexBuilder), @@ -90,13 +91,13 @@ void singleOrList() { @Test void switchExpression() { - Expression switchExpression = + final Expression switchExpression = b.switchExpression( b.fieldReference(commonTable, 0), List.of(b.switchClause(b.i32(5), b.i32(1)), b.switchClause(b.i32(10), b.i32(2))), b.i32(3)); - RexNode rexNode = switchExpression.accept(converter, Context.newContext()); - Expression expression = + final RexNode rexNode = switchExpression.accept(converter, Context.newContext()); + final Expression expression = rexNode.accept( new RexExpressionConverter( CASE, new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory))); @@ -113,7 +114,7 @@ void switchExpression() { @Test void castFailureCondition() { - Rel rel = + final Rel rel = b.project( input -> List.of( @@ -129,7 +130,7 @@ void castFailureCondition() { assertFullRoundTrip(rel); } - void assertExpressionEquality(Expression expected, Expression actual) { + void assertExpressionEquality(final Expression expected, final Expression actual) { // go the extra mile and convert both inputs to protobuf // helps verify that the protobuf conversion is not broken assertEquals( @@ -144,8 +145,8 @@ void supportedPrecisionForPrecisionTimestampLiteral() { assertPrecisionTimestampLiteral(6); } - void assertPrecisionTimestampLiteral(int precision) { - RexNode calciteExpr = + void assertPrecisionTimestampLiteral(final int precision) { + final RexNode calciteExpr = Expression.PrecisionTimestampLiteral.builder() .value(0) .precision(precision) @@ -161,8 +162,8 @@ void supportedPrecisionForPrecisionTimestampTZLiteral() { assertPrecisionTimestampTZLiteral(6); } - void assertPrecisionTimestampTZLiteral(int precision) { - RexNode calciteExpr = + void assertPrecisionTimestampTZLiteral(final int precision) { + final RexNode calciteExpr = Expression.PrecisionTimestampTZLiteral.builder() .value(0) .precision(precision) @@ -197,7 +198,7 @@ void unsupportedPrecisionForPrecisionTimestampLiteral() { assertThrowsUnsupportedPrecisionPrecisionTimestampLiteral(13); } - void assertThrowsUnsupportedPrecisionPrecisionTimestampLiteral(int precision) { + void assertThrowsUnsupportedPrecisionPrecisionTimestampLiteral(final int precision) { assertThrowsExpressionLiteral( Expression.PrecisionTimestampLiteral.builder().value(0).precision(precision).build()); } @@ -228,12 +229,12 @@ void unsupportedPrecisionPrecisionTimestampTZLiteral() { assertThrowsUnsupportedPrecisionPrecisionTimestampTZLiteral(13); } - void assertThrowsUnsupportedPrecisionPrecisionTimestampTZLiteral(int precision) { + void assertThrowsUnsupportedPrecisionPrecisionTimestampTZLiteral(final int precision) { assertThrowsExpressionLiteral( Expression.PrecisionTimestampTZLiteral.builder().value(0).precision(precision).build()); } - void assertThrowsExpressionLiteral(Expression.Literal expr) { + void assertThrowsExpressionLiteral(final Expression.Literal expr) { assertThrows( UnsupportedOperationException.class, () -> expr.accept(converter, Context.newContext())); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java index f11aa55d5..b30427f96 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExtendedExpressionTestBase.java @@ -11,12 +11,13 @@ import org.junit.jupiter.api.Assertions; public class ExtendedExpressionTestBase { - public static String asString(String resource) throws IOException { + public static String asString(final String resource) throws IOException { return Resources.toString(Resources.getResource(resource), Charsets.UTF_8); } - public static List tpchSchemaCreateStatements(String schemaToLoad) throws IOException { - String[] values = asString(schemaToLoad).split(";"); + public static List tpchSchemaCreateStatements(final String schemaToLoad) + throws IOException { + final String[] values = asString(schemaToLoad).split(";"); return Arrays.stream(values) .filter(t -> !t.trim().isBlank()) .collect(java.util.stream.Collectors.toList()); @@ -26,39 +27,39 @@ public static List tpchSchemaCreateStatements() throws IOException { return tpchSchemaCreateStatements("tpch/schema.sql"); } - protected void assertProtoExtendedExpressionRoundtrip(String expressions) + protected void assertProtoExtendedExpressionRoundtrip(final String expressions) throws SqlParseException, IOException { // proto initial extended expression - io.substrait.proto.ExtendedExpression extendedExpressionProtoInitial = + final io.substrait.proto.ExtendedExpression extendedExpressionProtoInitial = new SqlExpressionToSubstrait().convert(expressions, tpchSchemaCreateStatements()); asserProtoExtendedExpression(extendedExpressionProtoInitial); } - protected void assertProtoExtendedExpressionRoundtrip(String expressions, String schemaToLoad) - throws SqlParseException, IOException { + protected void assertProtoExtendedExpressionRoundtrip( + final String expressions, final String schemaToLoad) throws SqlParseException, IOException { // proto initial extended expression - io.substrait.proto.ExtendedExpression extendedExpressionProtoInitial = + final io.substrait.proto.ExtendedExpression extendedExpressionProtoInitial = new SqlExpressionToSubstrait() .convert(expressions, tpchSchemaCreateStatements(schemaToLoad)); asserProtoExtendedExpression(extendedExpressionProtoInitial); } - protected void assertProtoExtendedExpressionRoundtrip(String[] expression) + protected void assertProtoExtendedExpressionRoundtrip(final String[] expression) throws SqlParseException, IOException { // proto initial extended expression - io.substrait.proto.ExtendedExpression extendedExpressionProtoInitial = + final io.substrait.proto.ExtendedExpression extendedExpressionProtoInitial = new SqlExpressionToSubstrait().convert(expression, tpchSchemaCreateStatements()); asserProtoExtendedExpression(extendedExpressionProtoInitial); } private static void asserProtoExtendedExpression( - io.substrait.proto.ExtendedExpression extendedExpressionProtoInitial) { + final io.substrait.proto.ExtendedExpression extendedExpressionProtoInitial) { // pojo final extended expression - io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = + final io.substrait.extendedexpression.ExtendedExpression extendedExpressionPojoFinal = new ProtoExtendedExpressionConverter().from(extendedExpressionProtoInitial); // proto final extended expression - io.substrait.proto.ExtendedExpression extendedExpressionProtoFinal = + final io.substrait.proto.ExtendedExpression extendedExpressionProtoFinal = new ExtendedExpressionProtoConverter().toProto(extendedExpressionPojoFinal); // round-trip to validate extended expression proto initial equals to final diff --git a/isthmus/src/test/java/io/substrait/isthmus/FetchTest.java b/isthmus/src/test/java/io/substrait/isthmus/FetchTest.java index 8e9824490..c523983a1 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/FetchTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/FetchTest.java @@ -16,19 +16,19 @@ class FetchTest extends PlanTestBase { @Test void limitOnly() { - Rel rel = b.limit(50, TABLE); + final Rel rel = b.limit(50, TABLE); assertFullRoundTrip(rel); } @Test void offsetOnly() { - Rel rel = b.offset(50, TABLE); + final Rel rel = b.offset(50, TABLE); assertFullRoundTrip(rel); } @Test void offsetAndLimit() { - Rel rel = b.fetch(50, 10, TABLE); + final Rel rel = b.fetch(50, 10, TABLE); assertFullRoundTrip(rel); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java b/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java index 102aa2f75..efdcb8c99 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java @@ -52,7 +52,7 @@ void subtractDateIDay() { // mapped to the wrong // Calcite function. // TODO: https://github.com/substrait-io/substrait-java/issues/377 - Expression.ScalarFunctionInvocation expr = + final Expression.ScalarFunctionInvocation expr = b.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "subtract:date_iday", @@ -60,7 +60,7 @@ void subtractDateIDay() { ExpressionCreator.date(false, 10561), ExpressionCreator.intervalDay(false, 120, 0, 0, 6)); - RexNode calciteExpr = expr.accept(expressionRexConverter, Context.newContext()); + final RexNode calciteExpr = expr.accept(expressionRexConverter, Context.newContext()); assertEquals( TypeConverter.DEFAULT.toCalcite(typeFactory, TypeCreator.REQUIRED.DATE), calciteExpr.getType()); @@ -71,7 +71,7 @@ void subtractDateIDay() { @Test void extractTimestampTzScalarFunction() { - ScalarFunctionInvocation reqTstzFn = + final ScalarFunctionInvocation reqTstzFn = b.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_tstz_str", @@ -80,11 +80,11 @@ void extractTimestampTzScalarFunction() { Expression.TimestampTZLiteral.builder().value(0).build(), Expression.StrLiteral.builder().value("GMT").build()); - RexNode calciteExpr = reqTstzFn.accept(expressionRexConverter, Context.newContext()); + final RexNode calciteExpr = reqTstzFn.accept(expressionRexConverter, Context.newContext()); assertEquals(SqlKind.EXTRACT, calciteExpr.getKind()); assertInstanceOf(RexCall.class, calciteExpr); - RexCall extract = (RexCall) calciteExpr; + final RexCall extract = (RexCall) calciteExpr; assertEquals( "EXTRACT(FLAG(MONTH), 1970-01-01 00:00:00:TIMESTAMP_WITH_LOCAL_TIME_ZONE(6), 'GMT':VARCHAR)", extract.toString()); @@ -92,7 +92,7 @@ void extractTimestampTzScalarFunction() { @Test void extractPrecisionTimestampTzScalarFunction() { - ScalarFunctionInvocation reqPtstzFn = + final ScalarFunctionInvocation reqPtstzFn = b.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_ptstz_str", @@ -101,11 +101,11 @@ void extractPrecisionTimestampTzScalarFunction() { Expression.PrecisionTimestampTZLiteral.builder().value(0).precision(6).build(), Expression.StrLiteral.builder().value("GMT").build()); - RexNode calciteExpr = reqPtstzFn.accept(expressionRexConverter, Context.newContext()); + final RexNode calciteExpr = reqPtstzFn.accept(expressionRexConverter, Context.newContext()); assertEquals(SqlKind.EXTRACT, calciteExpr.getKind()); assertInstanceOf(RexCall.class, calciteExpr); - RexCall extract = (RexCall) calciteExpr; + final RexCall extract = (RexCall) calciteExpr; assertEquals( "EXTRACT(FLAG(MONTH), 1970-01-01 00:00:00:TIMESTAMP_WITH_LOCAL_TIME_ZONE(6), 'GMT':VARCHAR)", extract.toString()); @@ -113,7 +113,7 @@ void extractPrecisionTimestampTzScalarFunction() { @Test void extractTimestampScalarFunction() { - ScalarFunctionInvocation reqTsFn = + final ScalarFunctionInvocation reqTsFn = b.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_ts", @@ -121,17 +121,17 @@ void extractTimestampScalarFunction() { EnumArg.builder().value("MONTH").build(), Expression.TimestampLiteral.builder().value(0).build()); - RexNode calciteExpr = reqTsFn.accept(expressionRexConverter, Context.newContext()); + final RexNode calciteExpr = reqTsFn.accept(expressionRexConverter, Context.newContext()); assertEquals(SqlKind.EXTRACT, calciteExpr.getKind()); assertInstanceOf(RexCall.class, calciteExpr); - RexCall extract = (RexCall) calciteExpr; + final RexCall extract = (RexCall) calciteExpr; assertEquals("EXTRACT(FLAG(MONTH), 1970-01-01 00:00:00:TIMESTAMP(6))", extract.toString()); } @Test void extractPrecisionTimestampScalarFunction() { - ScalarFunctionInvocation reqPtsFn = + final ScalarFunctionInvocation reqPtsFn = b.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_pts", @@ -139,17 +139,17 @@ void extractPrecisionTimestampScalarFunction() { EnumArg.builder().value("MONTH").build(), Expression.PrecisionTimestampLiteral.builder().value(0).precision(6).build()); - RexNode calciteExpr = reqPtsFn.accept(expressionRexConverter, Context.newContext()); + final RexNode calciteExpr = reqPtsFn.accept(expressionRexConverter, Context.newContext()); assertEquals(SqlKind.EXTRACT, calciteExpr.getKind()); assertInstanceOf(RexCall.class, calciteExpr); - RexCall extract = (RexCall) calciteExpr; + final RexCall extract = (RexCall) calciteExpr; assertEquals("EXTRACT(FLAG(MONTH), 1970-01-01 00:00:00:TIMESTAMP(6))", extract.toString()); } @Test void extractDateScalarFunction() { - ScalarFunctionInvocation reqDateFn = + final ScalarFunctionInvocation reqDateFn = b.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_date", @@ -157,17 +157,17 @@ void extractDateScalarFunction() { EnumArg.builder().value("MONTH").build(), Expression.DateLiteral.builder().value(0).build()); - RexNode calciteExpr = reqDateFn.accept(expressionRexConverter, Context.newContext()); + final RexNode calciteExpr = reqDateFn.accept(expressionRexConverter, Context.newContext()); assertEquals(SqlKind.EXTRACT, calciteExpr.getKind()); assertInstanceOf(RexCall.class, calciteExpr); - RexCall extract = (RexCall) calciteExpr; + final RexCall extract = (RexCall) calciteExpr; assertEquals("EXTRACT(FLAG(MONTH), 1970-01-01)", extract.toString()); } @Test void extractTimeScalarFunction() { - ScalarFunctionInvocation reqTimeFn = + final ScalarFunctionInvocation reqTimeFn = b.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_time", @@ -175,17 +175,17 @@ void extractTimeScalarFunction() { EnumArg.builder().value("MINUTE").build(), Expression.TimeLiteral.builder().value(0).build()); - RexNode calciteExpr = reqTimeFn.accept(expressionRexConverter, Context.newContext()); + final RexNode calciteExpr = reqTimeFn.accept(expressionRexConverter, Context.newContext()); assertEquals(SqlKind.EXTRACT, calciteExpr.getKind()); assertInstanceOf(RexCall.class, calciteExpr); - RexCall extract = (RexCall) calciteExpr; + final RexCall extract = (RexCall) calciteExpr; assertEquals("EXTRACT(FLAG(MINUTE), 00:00:00:TIME(6))", extract.toString()); } @Test void unsupportedExtractTimestampTzWithIndexing() { - ScalarFunctionInvocation reqReqTstzFn = + final ScalarFunctionInvocation reqReqTstzFn = b.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_tstz_str", @@ -202,7 +202,7 @@ void unsupportedExtractTimestampTzWithIndexing() { @Test void unsupportedExtractPrecisionTimestampTzWithIndexing() { - ScalarFunctionInvocation reqReqPtstzFn = + final ScalarFunctionInvocation reqReqPtstzFn = b.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_ptstz_str", @@ -219,7 +219,7 @@ void unsupportedExtractPrecisionTimestampTzWithIndexing() { @Test void unsupportedExtractTimestampWithIndexing() { - ScalarFunctionInvocation reqReqTsFn = + final ScalarFunctionInvocation reqReqTsFn = b.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_ts", @@ -235,7 +235,7 @@ void unsupportedExtractTimestampWithIndexing() { @Test void unsupportedExtractPrecisionTimestampWithIndexing() { - ScalarFunctionInvocation reqReqPtsFn = + final ScalarFunctionInvocation reqReqPtsFn = b.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_pts", @@ -251,7 +251,7 @@ void unsupportedExtractPrecisionTimestampWithIndexing() { @Test void unsupportedExtractDateWithIndexing() { - ScalarFunctionInvocation reqReqDateFn = + final ScalarFunctionInvocation reqReqDateFn = b.scalarFn( DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_date", diff --git a/isthmus/src/test/java/io/substrait/isthmus/KeyConstraintsTest.java b/isthmus/src/test/java/io/substrait/isthmus/KeyConstraintsTest.java index 22d2a4e99..fc6aad063 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/KeyConstraintsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/KeyConstraintsTest.java @@ -9,10 +9,10 @@ class KeyConstraintsTest extends PlanTestBase { @ParameterizedTest @ValueSource(ints = {7}) - void tpcds(int query) throws Exception { - SqlToSubstrait s = new SqlToSubstrait(); - String values = asString("keyconstraints_schema.sql"); - Prepare.CatalogReader catalog = + void tpcds(final int query) throws Exception { + final SqlToSubstrait s = new SqlToSubstrait(); + final String values = asString("keyconstraints_schema.sql"); + final Prepare.CatalogReader catalog = SubstraitCreateStatementParser.processCreateStatementsToCatalog(values); s.convert(asString(String.format("tpcds/queries/%02d.sql", query)), catalog); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/LogarithmicFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LogarithmicFunctionTest.java index ada6cfacd..0669bf9b5 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LogarithmicFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LogarithmicFunctionTest.java @@ -10,22 +10,22 @@ class LogarithmicFunctionTest extends PlanTestBase { @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void ln(String column) throws Exception { - String query = String.format("SELECT ln(%s) FROM numbers", column); + void ln(final String column) throws Exception { + final String query = String.format("SELECT ln(%s) FROM numbers", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void log10(String column) throws Exception { - String query = String.format("SELECT log10(%s) FROM numbers", column); + void log10(final String column) throws Exception { + final String query = String.format("SELECT log10(%s) FROM numbers", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i64", "fp32", "fp64"}) - void log2(String column) throws Exception { - String query = String.format("SELECT log2(%s) FROM numbers", column); + void log2(final String column) throws Exception { + final String query = String.format("SELECT log2(%s) FROM numbers", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java index ed63f0c47..0ed36d295 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -15,38 +15,39 @@ class NameRoundtripTest extends PlanTestBase { @Test void preserveNamesFromSql() throws Exception { - String createStatement = "CREATE TABLE foo(a BIGINT, b BIGINT)"; - CalciteCatalogReader catalogReader = + final String createStatement = "CREATE TABLE foo(a BIGINT, b BIGINT)"; + final CalciteCatalogReader catalogReader = SubstraitCreateStatementParser.processCreateStatementsToCatalog(createStatement); - SubstraitToCalcite substraitToCalcite = + final SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory); - String query = "SELECT \"a\", \"B\" FROM foo GROUP BY a, b"; - List expectedNames = List.of("a", "B"); + final String query = "SELECT \"a\", \"B\" FROM foo GROUP BY a, b"; + final List expectedNames = List.of("a", "B"); - org.apache.calcite.rel.RelRoot calciteRelRoot1 = + final org.apache.calcite.rel.RelRoot calciteRelRoot1 = SubstraitSqlToCalcite.convertQuery(query, catalogReader); assertEquals(expectedNames, calciteRelRoot1.validatedRowType.getFieldNames()); - io.substrait.plan.Plan.Root substraitRelRoot = + final io.substrait.plan.Plan.Root substraitRelRoot = SubstraitRelVisitor.convert(calciteRelRoot1, EXTENSION_COLLECTION); assertEquals(expectedNames, substraitRelRoot.getNames()); - org.apache.calcite.rel.RelRoot calciteRelRoot2 = substraitToCalcite.convert(substraitRelRoot); + final org.apache.calcite.rel.RelRoot calciteRelRoot2 = + substraitToCalcite.convert(substraitRelRoot); assertEquals(expectedNames, calciteRelRoot2.validatedRowType.getFieldNames()); } @Test void preserveNamesFromSubstrait() { - NamedScan rel = + final NamedScan rel = substraitBuilder.namedScan( List.of("foo"), List.of("i64", "struct", "struct0", "struct1"), List.of(R.I64, R.struct(R.FP64, R.STRING))); - Plan.Root planRoot = + final Plan.Root planRoot = Plan.Root.builder().input(rel).names(List.of("i", "s", "s0", "s1")).build(); assertFullRoundTrip(planRoot); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java index f1400526a..7897b7c14 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java @@ -25,15 +25,19 @@ class NestedStructQueryTest extends PlanTestBase { private class TypeHelper { private final RelDataTypeFactory factory; - public TypeHelper(RelDataTypeFactory factory) { + public TypeHelper(final RelDataTypeFactory factory) { this.factory = factory; } - RelDataType struct(String field, RelDataType value) { + RelDataType struct(final String field, final RelDataType value) { return factory.createStructType(Arrays.asList(Pair.of(field, value))); } - RelDataType struct2(String field1, RelDataType value1, String field2, RelDataType value2) { + RelDataType struct2( + final String field1, + final RelDataType value1, + final String field2, + final RelDataType value2) { return factory.createStructType( Arrays.asList(Pair.of(field1, value1), Pair.of(field2, value2))); } @@ -46,28 +50,29 @@ RelDataType string() { return factory.createSqlType(SqlTypeName.VARCHAR); } - RelDataType list(RelDataType elementType) { + RelDataType list(final RelDataType elementType) { return factory.createArrayType(elementType, -1); } - RelDataType map(RelDataType key, RelDataType value) { + RelDataType map(final RelDataType key, final RelDataType value) { return factory.createMapType(key, value); } } - private void test(Table table, String query, String expectedExpressionText) + private void test(final Table table, final String query, final String expectedExpressionText) throws SqlParseException, IOException { final Schema schema = new SubstraitSchema(Map.of("my_table", table)); final CalciteCatalogReader catalog = schemaToCatalog("nested", schema); final SqlToSubstrait sqlToSubstrait = new SqlToSubstrait(); - Plan plan = toProto(sqlToSubstrait.convert(query, catalog)); - Expression obtainedExpression = + final Plan plan = toProto(sqlToSubstrait.convert(query, catalog)); + final Expression obtainedExpression = plan.getRelations(0).getRoot().getInput().getProject().getExpressions(0); - Expression expectedExpression = TextFormat.parse(expectedExpressionText, Expression.class); + final Expression expectedExpression = + TextFormat.parse(expectedExpressionText, Expression.class); assertEquals(expectedExpression, obtainedExpression); - ProtoPlanConverter converter = new ProtoPlanConverter(); - io.substrait.plan.Plan plan2 = converter.from(plan); + final ProtoPlanConverter converter = new ProtoPlanConverter(); + final io.substrait.plan.Plan plan2 = converter.from(plan); assertPlanRoundtrip(plan2); } @@ -76,18 +81,18 @@ void testNestedStruct() throws SqlParseException, IOException { final Table table = new AbstractTable() { @Override - public RelDataType getRowType(RelDataTypeFactory factory) { - TypeHelper helper = new TypeHelper(factory); + public RelDataType getRowType(final RelDataTypeFactory factory) { + final TypeHelper helper = new TypeHelper(factory); return helper.struct2( "x", helper.i32(), "a", helper.i32()); } }; - String query = + final String query = "SELECT\n" + " \"nested\".\"my_table\".\"a\"\n" + "FROM\n" + " \"nested\".\"my_table\";"; - String expectedExpressionText = + final String expectedExpressionText = "selection {\n" + " direct_reference {\n" + " struct_field {\n" @@ -105,21 +110,21 @@ void testNestedStruct2() throws SqlParseException, IOException { final Table table = new AbstractTable() { @Override - public RelDataType getRowType(RelDataTypeFactory factory) { - TypeHelper helper = new TypeHelper(factory); + public RelDataType getRowType(final RelDataTypeFactory factory) { + final TypeHelper helper = new TypeHelper(factory); return helper.struct2( "x", helper.i32(), "a", helper.struct("b", helper.i32())); } }; - String query = + final String query = "SELECT\n" + " \"nested\".\"my_table\".\"a\".\"b\"\n" + "FROM\n" + " \"nested\".\"my_table\";"; - String expectedExpressionText = + final String expectedExpressionText = "selection {\n" + " direct_reference {\n" + " struct_field {\n" @@ -142,21 +147,21 @@ void testNestedStruct3() throws SqlParseException, IOException { final Table table = new AbstractTable() { @Override - public RelDataType getRowType(RelDataTypeFactory factory) { - TypeHelper helper = new TypeHelper(factory); + public RelDataType getRowType(final RelDataTypeFactory factory) { + final TypeHelper helper = new TypeHelper(factory); return helper.struct2( "aa", helper.i32(), "a", helper.struct("b", helper.struct("c", helper.i32()))); } }; - String query = + final String query = "SELECT\n" + " \"nested\".\"my_table\".\"a\".\"b\".\"c\"\n" + "FROM\n" + " \"nested\".\"my_table\";"; - String expectedExpressionText = + final String expectedExpressionText = "selection {\n" + " direct_reference {\n" + " struct_field {\n" @@ -184,20 +189,20 @@ void testNestedList() throws SqlParseException, IOException { final Table table = new AbstractTable() { @Override - public RelDataType getRowType(RelDataTypeFactory factory) { - TypeHelper helper = new TypeHelper(factory); + public RelDataType getRowType(final RelDataTypeFactory factory) { + final TypeHelper helper = new TypeHelper(factory); return helper.struct2("x", helper.i32(), "a", helper.list(helper.i32())); } }; - String query = + final String query = "SELECT\n" + " \"nested\".\"my_table\".\"a\"[1]\n" + "FROM\n" + " \"nested\".\"my_table\";"; - String expectedExpressionText = + final String expectedExpressionText = "selection {\n" + " direct_reference {\n" + " struct_field {\n" @@ -220,8 +225,8 @@ void testNestedList2() throws SqlParseException, IOException { final Table table = new AbstractTable() { @Override - public RelDataType getRowType(RelDataTypeFactory factory) { - TypeHelper helper = new TypeHelper(factory); + public RelDataType getRowType(final RelDataTypeFactory factory) { + final TypeHelper helper = new TypeHelper(factory); return helper.struct2( "x", @@ -231,13 +236,13 @@ public RelDataType getRowType(RelDataTypeFactory factory) { } }; - String query = + final String query = "SELECT\n" + " \"nested\".\"my_table\".\"a\"[1][2][3]\n" + "FROM\n" + " \"nested\".\"my_table\";"; - String expectedExpressionText = + final String expectedExpressionText = "selection {\n" + " direct_reference {\n" + " struct_field {\n" @@ -271,9 +276,9 @@ void testProtobufDoc() throws SqlParseException, IOException { final Table table = new AbstractTable() { @Override - public RelDataType getRowType(RelDataTypeFactory factory) { + public RelDataType getRowType(final RelDataTypeFactory factory) { - TypeHelper helper = new TypeHelper(factory); + final TypeHelper helper = new TypeHelper(factory); return helper.struct( "a", helper.struct( @@ -284,13 +289,13 @@ public RelDataType getRowType(RelDataTypeFactory factory) { } }; - String query = + final String query = "SELECT\n" + " \"nested\".\"my_table\".a.b[2].c['my_map_key'].x\n" + "FROM\n" + " \"nested\".\"my_table\";"; - String expectedExpressionText = + final String expectedExpressionText = " selection {\n" + " direct_reference {\n" + " struct_field {\n" diff --git a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java index b408303a0..76e39fd0e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java @@ -18,26 +18,26 @@ class OptimizerIntegrationTest extends PlanTestBase { @Test void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOException { - String query = + final String query = "select O_CUSTKEY, count(distinct O_ORDERKEY), count(*) from orders group by O_CUSTKEY"; // verify that the query works generally assertFullRoundTrip(query); - RelRoot relRoot = SubstraitSqlToCalcite.convertQuery(query, TPCH_CATALOG); - RelNode originalPlan = relRoot.rel; + final RelRoot relRoot = SubstraitSqlToCalcite.convertQuery(query, TPCH_CATALOG); + final RelNode originalPlan = relRoot.rel; // Create a program to apply the AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN rule. // This will introduce a SqlSumEmptyIsZeroAggFunction to the plan. // This function does not have a mapping to Substrait. // SubstraitSumEmptyIsZeroAggFunction is the variant which has a mapping. // See io.substrait.isthmus.AggregateFunctions for details - HepProgram program = + final HepProgram program = new HepProgramBuilder() .addRuleInstance(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN) .build(); - HepPlanner planner = new HepPlanner(program); + final HepPlanner planner = new HepPlanner(program); planner.setRoot(originalPlan); - RelNode newPlan = planner.findBestExp(); + final RelNode newPlan = planner.findBestExp(); assertDoesNotThrow( () -> diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index cce58e207..3de4c1006 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -51,7 +51,7 @@ public class PlanTestBase { static { try { - String tpchCreateStatements = asString("tpch/schema.sql"); + final String tpchCreateStatements = asString("tpch/schema.sql"); TPCH_CATALOG = SubstraitCreateStatementParser.processCreateStatementsToCatalog(tpchCreateStatements); } catch (IOException | SqlParseException e) { @@ -63,33 +63,35 @@ public class PlanTestBase { protected static CalciteCatalogReader TPCDS_CATALOG = PlanTestBase.schemaToCatalog("tpcds", TPCDS_SCHEMA); - public static String asString(String resource) throws IOException { + public static String asString(final String resource) throws IOException { return Resources.toString(Resources.getResource(resource), Charsets.UTF_8); } - protected Plan assertProtoPlanRoundrip(String query) throws SqlParseException { + protected Plan assertProtoPlanRoundrip(final String query) throws SqlParseException { return assertProtoPlanRoundrip(query, new SqlToSubstrait()); } - protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s) throws SqlParseException { + protected Plan assertProtoPlanRoundrip(final String query, final SqlToSubstrait s) + throws SqlParseException { return assertProtoPlanRoundrip(query, s, TPCH_CATALOG); } - protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s, String createStatements) + protected Plan assertProtoPlanRoundrip( + final String query, final SqlToSubstrait s, final String createStatements) throws SqlParseException { - Prepare.CatalogReader catalog = + final Prepare.CatalogReader catalog = SubstraitCreateStatementParser.processCreateStatementsToCatalog(createStatements); return assertProtoPlanRoundrip(query, s, catalog); } protected Plan assertProtoPlanRoundrip( - String query, SqlToSubstrait s, Prepare.CatalogReader catalogReader) + final String query, final SqlToSubstrait s, final Prepare.CatalogReader catalogReader) throws SqlParseException { - Plan plan1 = s.convert(query, catalogReader); - io.substrait.proto.Plan protoPlan1 = toProto(plan1); + final Plan plan1 = s.convert(query, catalogReader); + final io.substrait.proto.Plan protoPlan1 = toProto(plan1); - Plan plan2 = new ProtoPlanConverter(extensions).from(protoPlan1); - io.substrait.proto.Plan protoPlan2 = toProto(plan2); + final Plan plan2 = new ProtoPlanConverter(extensions).from(protoPlan1); + final io.substrait.proto.Plan protoPlan2 = toProto(plan2); assertEquals(protoPlan1, protoPlan2); assertEquals(plan1.getRoots().size(), plan2.getRoots().size()); @@ -102,57 +104,57 @@ protected Plan assertProtoPlanRoundrip( return plan2; } - protected void assertPlanRoundtrip(Plan plan) { - io.substrait.proto.Plan protoPlan1 = toProto(plan); - io.substrait.proto.Plan protoPlan2 = toProto(new ProtoPlanConverter().from(protoPlan1)); + protected void assertPlanRoundtrip(final Plan plan) { + final io.substrait.proto.Plan protoPlan1 = toProto(plan); + final io.substrait.proto.Plan protoPlan2 = toProto(new ProtoPlanConverter().from(protoPlan1)); assertEquals(protoPlan1, protoPlan2); } - protected RelRoot assertSqlSubstraitRelRoundTrip(String query) throws Exception { + protected RelRoot assertSqlSubstraitRelRoundTrip(final String query) throws Exception { return assertSqlSubstraitRelRoundTrip(query, TPCH_CATALOG); } - protected RelRoot assertSqlSubstraitRelRoundTrip(String query, String createStatements) - throws Exception { - CalciteCatalogReader catalogReader = + protected RelRoot assertSqlSubstraitRelRoundTrip( + final String query, final String createStatements) throws Exception { + final CalciteCatalogReader catalogReader = SubstraitCreateStatementParser.processCreateStatementsToCatalog(createStatements); return assertSqlSubstraitRelRoundTrip(query, catalogReader); } protected RelRoot assertSqlSubstraitRelRoundTrip( - String query, Prepare.CatalogReader catalogReader) throws Exception { + final String query, final Prepare.CatalogReader catalogReader) throws Exception { // sql <--> substrait round trip test. // Assert (sql -> calcite -> substrait) and (sql -> substrait -> calcite -> substrait) are same. // Return list of sql -> Substrait rel -> Calcite rel. - SqlToSubstrait s2s = new SqlToSubstrait(); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); + final SqlToSubstrait s2s = new SqlToSubstrait(); + final SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); // 1. SQL -> Substrait Plan - Plan plan1 = s2s.convert(query, catalogReader); + final Plan plan1 = s2s.convert(query, catalogReader); // 2. Substrait Plan -> Substrait Rel - Plan.Root pojo1 = plan1.getRoots().get(0); + final Plan.Root pojo1 = plan1.getRoots().get(0); // 3. Substrait Rel -> Calcite RelNode - RelRoot relRoot2 = substraitToCalcite.convert(pojo1); + final RelRoot relRoot2 = substraitToCalcite.convert(pojo1); // 4. Calcite RelNode -> Substrait Rel - Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, extensions); + final Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, extensions); assertEquals(pojo1, pojo2); return relRoot2; } @Beta - protected void assertFullRoundTrip(String query) throws SqlParseException { + protected void assertFullRoundTrip(final String query) throws SqlParseException { assertFullRoundTrip(query, TPCH_CATALOG); } @Beta - protected void assertFullRoundTrip(String query, String createStatements) + protected void assertFullRoundTrip(final String query, final String createStatements) throws SqlParseException { - CalciteCatalogReader catalogReader = + final CalciteCatalogReader catalogReader = SubstraitCreateStatementParser.processCreateStatementsToCatalog(createStatements); assertFullRoundTrip(query, catalogReader); } @@ -170,21 +172,22 @@ protected void assertFullRoundTrip(String query, String createStatements) *

  • Substrait POJO 2 == Substrait POJO 3 * */ - protected void assertFullRoundTrip(String sqlQuery, Prepare.CatalogReader catalogReader) - throws SqlParseException { - ExtensionCollector extensionCollector = new ExtensionCollector(); + protected void assertFullRoundTrip( + final String sqlQuery, final Prepare.CatalogReader catalogReader) throws SqlParseException { + final ExtensionCollector extensionCollector = new ExtensionCollector(); // SQL -> Calcite 1 - RelRoot calcite1 = SubstraitSqlToCalcite.convertQuery(sqlQuery, catalogReader); + final RelRoot calcite1 = SubstraitSqlToCalcite.convertQuery(sqlQuery, catalogReader); // Calcite 1 -> Substrait POJO 1 - Plan.Root root1 = SubstraitRelVisitor.convert(calcite1, extensions); + final Plan.Root root1 = SubstraitRelVisitor.convert(calcite1, extensions); // Substrait Root 1 -> Substrait Proto - io.substrait.proto.RelRoot proto = new RelProtoConverter(extensionCollector).toProto(root1); + final io.substrait.proto.RelRoot proto = + new RelProtoConverter(extensionCollector).toProto(root1); // Substrait Proto -> Substrait Root 2 - Plan.Root root2 = new ProtoRelConverter(extensionCollector, extensions).from(proto); + final Plan.Root root2 = new ProtoRelConverter(extensionCollector, extensions).from(proto); // Verify that roots are the same assertEquals(root1, root2); @@ -193,13 +196,13 @@ protected void assertFullRoundTrip(String sqlQuery, Prepare.CatalogReader catalo final SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory, catalogReader); - RelRoot calcite2 = substraitToCalcite.convert(root2); + final RelRoot calcite2 = substraitToCalcite.convert(root2); // It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to // do so assertNotNull(calcite2); // Calcite 2 -> Substrait Root 3 - Plan.Root root3 = SubstraitRelVisitor.convert(calcite2, extensions); + final Plan.Root root3 = SubstraitRelVisitor.convert(calcite2, extensions); // Verify that POJOs are the same assertEquals(root1, root3); @@ -229,21 +232,22 @@ protected void assertFullRoundTrip(String sqlQuery, Prepare.CatalogReader catalo * */ protected void assertFullRoundTripWithIdentityProjectionWorkaround( - String sqlQuery, Prepare.CatalogReader catalogReader) throws SqlParseException { - ExtensionCollector extensionCollector = new ExtensionCollector(); + final String sqlQuery, final Prepare.CatalogReader catalogReader) throws SqlParseException { + final ExtensionCollector extensionCollector = new ExtensionCollector(); // Preparation // SQL -> Calcite 0 - RelRoot calcite0 = SubstraitSqlToCalcite.convertQuery(sqlQuery, catalogReader); + final RelRoot calcite0 = SubstraitSqlToCalcite.convertQuery(sqlQuery, catalogReader); // Calcite 0 -> Substrait POJO 0 - Plan.Root root0 = SubstraitRelVisitor.convert(calcite0, extensions); + final Plan.Root root0 = SubstraitRelVisitor.convert(calcite0, extensions); // Substrait POJO 0 -> Substrait Proto 0 - io.substrait.proto.RelRoot proto0 = new RelProtoConverter(extensionCollector).toProto(root0); + final io.substrait.proto.RelRoot proto0 = + new RelProtoConverter(extensionCollector).toProto(root0); // Substrait Proto -> Substrait POJO 1 - Plan.Root root1 = new ProtoRelConverter(extensionCollector, extensions).from(proto0); + final Plan.Root root1 = new ProtoRelConverter(extensionCollector, extensions).from(proto0); // Verify that POJOs are the same assertEquals(root0, root1); @@ -252,23 +256,24 @@ protected void assertFullRoundTripWithIdentityProjectionWorkaround( new SubstraitToCalcite(extensions, typeFactory, catalogReader); // Substrait POJO 1 -> Calcite 1 - RelRoot calcite1 = substraitToCalcite.convert(root1); + final RelRoot calcite1 = substraitToCalcite.convert(root1); // End Preparation // Calcite 1 -> Substrait POJO 2 - Plan.Root root2 = SubstraitRelVisitor.convert(calcite1, extensions); + final Plan.Root root2 = SubstraitRelVisitor.convert(calcite1, extensions); // Substrait POJO 2 -> Substrait Proto 1 - io.substrait.proto.RelRoot proto1 = new RelProtoConverter(extensionCollector).toProto(root2); + final io.substrait.proto.RelRoot proto1 = + new RelProtoConverter(extensionCollector).toProto(root2); // Substrait Proto1 -> Substrait POJO 3 - Plan.Root root3 = new ProtoRelConverter(extensionCollector, extensions).from(proto1); + final Plan.Root root3 = new ProtoRelConverter(extensionCollector, extensions).from(proto1); // Substrait POJO 3 -> Calcite 2 - RelRoot calcite2 = substraitToCalcite.convert(root3); + final RelRoot calcite2 = substraitToCalcite.convert(root3); // Calcite 2 -> Substrait POJO 4 - Plan.Root root4 = SubstraitRelVisitor.convert(calcite2, extensions); + final Plan.Root root4 = SubstraitRelVisitor.convert(calcite2, extensions); // Verify that POJOs are the same assertEquals(root2, root4); @@ -282,25 +287,25 @@ protected void assertFullRoundTripWithIdentityProjectionWorkaround( *
  • From POJO to Calcite and back * */ - protected void assertFullRoundTrip(Rel pojo1) { + protected void assertFullRoundTrip(final Rel pojo1) { // TODO: reuse the Plan.Root based assertFullRoundTrip by generating names - ExtensionCollector extensionCollector = new ExtensionCollector(); + final ExtensionCollector extensionCollector = new ExtensionCollector(); // Substrait POJO 1 -> Substrait Proto - io.substrait.proto.Rel proto = new RelProtoConverter(extensionCollector).toProto(pojo1); + final io.substrait.proto.Rel proto = new RelProtoConverter(extensionCollector).toProto(pojo1); // Substrait Proto -> Substrait Pojo 2 - io.substrait.relation.Rel pojo2 = + final io.substrait.relation.Rel pojo2 = new ProtoRelConverter(extensionCollector, extensions).from(proto); // Verify that POJOs are the same assertEquals(pojo1, pojo2); // Substrait POJO 2 -> Calcite - RelNode calcite = new SubstraitToCalcite(extensions, typeFactory).convert(pojo2); + final RelNode calcite = new SubstraitToCalcite(extensions, typeFactory).convert(pojo2); // Calcite -> Substrait POJO 3 - io.substrait.relation.Rel pojo3 = SubstraitRelVisitor.convert(calcite, extensions); + final io.substrait.relation.Rel pojo3 = SubstraitRelVisitor.convert(calcite, extensions); // Verify that POJOs are the same assertEquals(pojo1, pojo3); @@ -314,68 +319,70 @@ protected void assertFullRoundTrip(Rel pojo1) { *
  • From POJO to Calcite and back * */ - protected void assertFullRoundTrip(Plan.Root pojo1) { - ExtensionCollector extensionCollector = new ExtensionCollector(); + protected void assertFullRoundTrip(final Plan.Root pojo1) { + final ExtensionCollector extensionCollector = new ExtensionCollector(); // Substrait POJO 1 -> Substrait Proto - io.substrait.proto.RelRoot proto = new RelProtoConverter(extensionCollector).toProto(pojo1); + final io.substrait.proto.RelRoot proto = + new RelProtoConverter(extensionCollector).toProto(pojo1); // Substrait Proto -> Substrait Pojo 2 - io.substrait.plan.Plan.Root pojo2 = + final io.substrait.plan.Plan.Root pojo2 = new ProtoRelConverter(extensionCollector, extensions).from(proto); // Verify that POJOs are the same assertEquals(pojo1, pojo2); // Substrait POJO 2 -> Calcite - RelRoot calcite = new SubstraitToCalcite(extensions, typeFactory).convert(pojo2); + final RelRoot calcite = new SubstraitToCalcite(extensions, typeFactory).convert(pojo2); // Calcite -> Substrait POJO 3 - io.substrait.plan.Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite, extensions); + final io.substrait.plan.Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite, extensions); // Verify that POJOs are the same assertEquals(pojo1, pojo3); } - protected void assertRowMatch(RelDataType actual, Type... expected) { + protected void assertRowMatch(final RelDataType actual, final Type... expected) { assertRowMatch(actual, Arrays.asList(expected)); } - protected void assertRowMatch(RelDataType actual, List expected) { - Type type = TypeConverter.DEFAULT.toSubstrait(actual); + protected void assertRowMatch(final RelDataType actual, final List expected) { + final Type type = TypeConverter.DEFAULT.toSubstrait(actual); assertInstanceOf(Type.Struct.class, type); - Type.Struct struct = (Type.Struct) type; + final Type.Struct struct = (Type.Struct) type; assertEquals(expected, struct.fields()); } - protected Plan toSubstraitPlan(String sql, CalciteCatalogReader catalog) + protected Plan toSubstraitPlan(final String sql, final CalciteCatalogReader catalog) throws SqlParseException { return new SqlToSubstrait().convert(sql, catalog); } - protected String toSql(io.substrait.proto.Plan protoPlan) { - Plan plan = new ProtoPlanConverter(extensions).from(protoPlan); + protected String toSql(final io.substrait.proto.Plan protoPlan) { + final Plan plan = new ProtoPlanConverter(extensions).from(protoPlan); return toSql(plan); } - protected String toSql(Plan plan) { - List roots = plan.getRoots(); + protected String toSql(final Plan plan) { + final List roots = plan.getRoots(); assertEquals(1, roots.size(), "number of roots"); - Root root = roots.get(0); - RelRoot relRoot = new SubstraitToCalcite(extensions, typeFactory).convert(root); - RelNode project = relRoot.project(true); + final Root root = roots.get(0); + final RelRoot relRoot = new SubstraitToCalcite(extensions, typeFactory).convert(root); + final RelNode project = relRoot.project(true); return SubstraitSqlDialect.toSql(project).getSql(); } - protected io.substrait.proto.Plan toProto(Plan plan) { + protected io.substrait.proto.Plan toProto(final Plan plan) { return new PlanProtoConverter().toProto(plan); } - protected static CalciteCatalogReader schemaToCatalog(String schemaName, Schema schema) { - CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); + protected static CalciteCatalogReader schemaToCatalog( + final String schemaName, final Schema schema) { + final CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); rootSchema.add(schemaName, schema); - List defaultSchema = List.of(schemaName); + final List defaultSchema = List.of(schemaName); return new CalciteCatalogReader( rootSchema, defaultSchema, diff --git a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java index da8423c03..01545d047 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java @@ -20,8 +20,8 @@ class ProtoPlanConverterTest extends PlanTestBase { - private io.substrait.proto.Plan getProtoPlan(String query1) throws SqlParseException { - SqlToSubstrait s = new SqlToSubstrait(); + private io.substrait.proto.Plan getProtoPlan(final String query1) throws SqlParseException { + final SqlToSubstrait s = new SqlToSubstrait(); return toProto(s.convert(query1, TPCH_CATALOG)); } @@ -35,7 +35,7 @@ void simpleSelect() throws IOException, SqlParseException { assertProtoPlanRoundrip("select l_orderkey,l_extendedprice from lineitem"); } - private static void assertAggregateInvocationDistinct(io.substrait.proto.Plan plan) { + private static void assertAggregateInvocationDistinct(final io.substrait.proto.Plan plan) { assertEquals( AggregateFunction.AggregationInvocation.AGGREGATION_INVOCATION_DISTINCT, plan.getRelations(0) @@ -50,8 +50,8 @@ private static void assertAggregateInvocationDistinct(io.substrait.proto.Plan pl @Test void distinctCount() throws IOException, SqlParseException { - String distinctQuery = "select count(DISTINCT L_ORDERKEY) from lineitem"; - io.substrait.proto.Plan protoPlan = getProtoPlan(distinctQuery); + final String distinctQuery = "select count(DISTINCT L_ORDERKEY) from lineitem"; + final io.substrait.proto.Plan protoPlan = getProtoPlan(distinctQuery); assertAggregateInvocationDistinct(protoPlan); assertAggregateInvocationDistinct(toProto(new ProtoPlanConverter().from(protoPlan))); } @@ -63,40 +63,40 @@ void filter() throws IOException, SqlParseException { @Test void crossJoin() throws IOException, SqlParseException { - int[] counter = new int[1]; - RelCopyOnWriteVisitor crossJoinCountingVisitor = + final int[] counter = new int[1]; + final RelCopyOnWriteVisitor crossJoinCountingVisitor = new RelCopyOnWriteVisitor() { @Override - public Optional visit(Cross cross, EmptyVisitationContext context) + public Optional visit(final Cross cross, final EmptyVisitationContext context) throws RuntimeException { counter[0]++; return super.visit(cross, context); } }; - ImmutableFeatureBoard featureBoard = ImmutableFeatureBoard.builder().build(); + final ImmutableFeatureBoard featureBoard = ImmutableFeatureBoard.builder().build(); - String query1 = + final String query1 = "select\n" + " c.c_custKey,\n" + " o.o_custkey\n" + "from\n" + " \"customer\" c cross join\n" + " \"orders\" o"; - Plan plan1 = assertProtoPlanRoundrip(query1, new SqlToSubstrait(featureBoard)); + final Plan plan1 = assertProtoPlanRoundrip(query1, new SqlToSubstrait(featureBoard)); plan1 .getRoots() .forEach( t -> t.getInput().accept(crossJoinCountingVisitor, EmptyVisitationContext.INSTANCE)); assertEquals(1, counter[0]); - String query2 = + final String query2 = "select\n" + " c.c_custKey,\n" + " o.o_custkey\n" + "from\n" + " \"customer\" c,\n" + " \"orders\" o"; - Plan plan2 = assertProtoPlanRoundrip(query2, new SqlToSubstrait(featureBoard)); + final Plan plan2 = assertProtoPlanRoundrip(query2, new SqlToSubstrait(featureBoard)); plan2 .getRoots() .forEach( @@ -137,7 +137,7 @@ void joinAggSortLimit() throws IOException, SqlParseException { @ParameterizedTest @MethodSource("io.substrait.isthmus.utils.SetUtils#setTestConfig") - void setTest(Set.SetOp op, boolean multi) throws Exception { + void setTest(final Set.SetOp op, final boolean multi) throws Exception { assertProtoPlanRoundrip(SetUtils.getSetQuery(op, multi)); } @@ -167,7 +167,7 @@ void notInPredicateCorrelatedSubquery() throws IOException, SqlParseException { @Test void existsNestedCorrelatedSubquery() throws IOException, SqlParseException { - String sql = + final String sql = "SELECT p_partkey\n" + "FROM part p\n" + "WHERE EXISTS\n" diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java index 5f3fc09db..00118dc7d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java @@ -76,14 +76,14 @@ class RelCopyOnWriteVisitorTest extends PlanTestBase { + "from\n" + " \"orders\" o\n"; - private Plan buildPlanFromQuery(String query) throws IOException, SqlParseException { - SqlToSubstrait s = new SqlToSubstrait(); + private Plan buildPlanFromQuery(final String query) throws IOException, SqlParseException { + final SqlToSubstrait s = new SqlToSubstrait(); return s.convert(query, TPCH_CATALOG); } @Test void hasTableReference() throws IOException, SqlParseException { - Plan plan = + final Plan plan = buildPlanFromQuery( "SELECT p_partkey\n" + "FROM part p\n" @@ -96,7 +96,7 @@ void hasTableReference() throws IOException, SqlParseException { + " FROM partsupp ps\n" + " WHERE ps.ps_partkey = p.p_partkey\n" + " AND PS.ps_suppkey = l.l_suppkey))"); - HasTableReference action = new HasTableReference(); + final HasTableReference action = new HasTableReference(); assertTrue(action.hasTableReference(plan, "PARTSUPP")); assertTrue(action.hasTableReference(plan, "LINEITEM")); assertTrue(action.hasTableReference(plan, "PART")); @@ -105,17 +105,17 @@ void hasTableReference() throws IOException, SqlParseException { @Test void countCountDistincts() throws IOException, SqlParseException { - Plan plan = buildPlanFromQuery(COUNT_DISTINCT_SUBBQUERY); + final Plan plan = buildPlanFromQuery(COUNT_DISTINCT_SUBBQUERY); assertEquals(2, new CountCountDistinct().getCountDistincts(plan)); } @Test void replaceCountDistincts() throws IOException, SqlParseException { - Plan oldPlan = buildPlanFromQuery(COUNT_DISTINCT_SUBBQUERY); + final Plan oldPlan = buildPlanFromQuery(COUNT_DISTINCT_SUBBQUERY); assertEquals(2, new CountCountDistinct().getCountDistincts(oldPlan)); assertEquals(0, new CountApproxCountDistinct().getApproxCountDistincts(oldPlan)); - ReplaceCountDistinctWithApprox action = new ReplaceCountDistinctWithApprox(); - Plan newPlan = action.modify(oldPlan).orElse(oldPlan); + final ReplaceCountDistinctWithApprox action = new ReplaceCountDistinctWithApprox(); + final Plan newPlan = action.modify(oldPlan).orElse(oldPlan); assertEquals(2, new CountApproxCountDistinct().getApproxCountDistincts(newPlan)); assertEquals(0, new CountCountDistinct().getCountDistincts(newPlan)); assertPlanRoundtrip(newPlan); @@ -123,45 +123,46 @@ void replaceCountDistincts() throws IOException, SqlParseException { @Test void approximateCountDistinct() throws IOException, SqlParseException { - Plan oldPlan = + final Plan oldPlan = buildPlanFromQuery( "select count(distinct l_discount), count(distinct l_tax) from lineitem"); assertEquals(2, new CountCountDistinct().getCountDistincts(oldPlan)); assertEquals(0, new CountApproxCountDistinct().getApproxCountDistincts(oldPlan)); - ReplaceCountDistinctWithApprox action = new ReplaceCountDistinctWithApprox(); - Plan newPlan = action.modify(oldPlan).orElse(oldPlan); + final ReplaceCountDistinctWithApprox action = new ReplaceCountDistinctWithApprox(); + final Plan newPlan = action.modify(oldPlan).orElse(oldPlan); assertEquals(2, new CountApproxCountDistinct().getApproxCountDistincts(newPlan)); assertEquals(0, new CountCountDistinct().getCountDistincts(newPlan)); assertPlanRoundtrip(newPlan); // convert newPlan back to sql - Rel pojoRel = newPlan.getRoots().get(0).getInput(); - RelNode relnodeRoot = new SubstraitToSql().substraitRelToCalciteRel(pojoRel, TPCH_CATALOG); - String newSql = SubstraitSqlDialect.toSql(relnodeRoot).getSql(); + final Rel pojoRel = newPlan.getRoots().get(0).getInput(); + final RelNode relnodeRoot = + new SubstraitToSql().substraitRelToCalciteRel(pojoRel, TPCH_CATALOG); + final String newSql = SubstraitSqlDialect.toSql(relnodeRoot).getSql(); assertTrue(newSql.toUpperCase().contains("APPROX_COUNT_DISTINCT")); } @Test void countCountDistinctsUnion() throws IOException, SqlParseException { - Plan plan = buildPlanFromQuery(UNION_DISTINCT_COUNT_QUERY); + final Plan plan = buildPlanFromQuery(UNION_DISTINCT_COUNT_QUERY); assertEquals(2, new CountCountDistinct().getCountDistincts(plan)); } @Test void replaceCountDistinctsInUnion() throws IOException, SqlParseException { - Plan oldPlan = buildPlanFromQuery(UNION_DISTINCT_COUNT_QUERY); + final Plan oldPlan = buildPlanFromQuery(UNION_DISTINCT_COUNT_QUERY); assertEquals(2, new CountCountDistinct().getCountDistincts(oldPlan)); assertEquals(0, new CountApproxCountDistinct().getApproxCountDistincts(oldPlan)); - ReplaceCountDistinctWithApprox action = new ReplaceCountDistinctWithApprox(); - Plan newPlan = action.modify(oldPlan).orElse(oldPlan); + final ReplaceCountDistinctWithApprox action = new ReplaceCountDistinctWithApprox(); + final Plan newPlan = action.modify(oldPlan).orElse(oldPlan); assertEquals(2, new CountApproxCountDistinct().getApproxCountDistincts(newPlan)); assertEquals(0, new CountCountDistinct().getCountDistincts(newPlan)); assertPlanRoundtrip(newPlan); } private static class HasTableReference { - public boolean hasTableReference(Plan plan, String name) { - HasTableReferenceVisitor visitor = new HasTableReferenceVisitor(Arrays.asList(name)); + public boolean hasTableReference(final Plan plan, final String name) { + final HasTableReferenceVisitor visitor = new HasTableReferenceVisitor(Arrays.asList(name)); plan.getRoots().stream() .forEach(r -> r.getInput().accept(visitor, EmptyVisitationContext.INSTANCE)); return (visitor.hasTableReference()); @@ -171,7 +172,7 @@ private class HasTableReferenceVisitor extends RelCopyOnWriteVisitor tableName; - public HasTableReferenceVisitor(List tableName) { + public HasTableReferenceVisitor(final List tableName) { this.tableName = tableName; } @@ -180,7 +181,7 @@ public boolean hasTableReference() { } @Override - public Optional visit(NamedScan namedScan, EmptyVisitationContext context) { + public Optional visit(final NamedScan namedScan, final EmptyVisitationContext context) { this.hasTableReference |= namedScan.getNames().equals(tableName); return super.visit(namedScan, context); } @@ -189,8 +190,8 @@ public Optional visit(NamedScan namedScan, EmptyVisitationContext context) private static class CountCountDistinct { - public int getCountDistincts(Plan plan) { - CountCountDistinctVisitor visitor = new CountCountDistinctVisitor(); + public int getCountDistincts(final Plan plan) { + final CountCountDistinctVisitor visitor = new CountCountDistinctVisitor(); plan.getRoots().stream() .forEach(r -> r.getInput().accept(visitor, EmptyVisitationContext.INSTANCE)); return visitor.getCountDistincts(); @@ -204,7 +205,7 @@ public int getCountDistincts() { } @Override - public Optional visit(Aggregate aggregate, EmptyVisitationContext context) { + public Optional visit(final Aggregate aggregate, final EmptyVisitationContext context) { countDistincts += aggregate.getMeasures().stream() .filter( @@ -221,8 +222,8 @@ public Optional visit(Aggregate aggregate, EmptyVisitationContext context) private static class CountApproxCountDistinct { - public int getApproxCountDistincts(Plan plan) { - CountCountDistinctVisitor visitor = new CountCountDistinctVisitor(); + public int getApproxCountDistincts(final Plan plan) { + final CountCountDistinctVisitor visitor = new CountCountDistinctVisitor(); plan.getRoots().stream() .forEach(r -> r.getInput().accept(visitor, EmptyVisitationContext.INSTANCE)); return visitor.getApproxCountDistincts(); @@ -236,7 +237,7 @@ public int getApproxCountDistincts() { } @Override - public Optional visit(Aggregate aggregate, EmptyVisitationContext context) { + public Optional visit(final Aggregate aggregate, final EmptyVisitationContext context) { aproxCountDistincts += aggregate.getMeasures().stream() .filter( @@ -255,7 +256,7 @@ public ReplaceCountDistinctWithApprox() { new ReplaceCountDistinctWithApproxVisitor(DefaultExtensionCatalog.DEFAULT_COLLECTION); } - public Optional modify(Plan plan) { + public Optional modify(final Plan plan) { return CopyOnWriteUtils.transformList( plan.getRoots(), null, @@ -272,13 +273,13 @@ private static class ReplaceCountDistinctWithApproxVisitor private final SimpleExtension.AggregateFunctionVariant approxFunc; public ReplaceCountDistinctWithApproxVisitor( - SimpleExtension.ExtensionCollection extensionCollection) { + final SimpleExtension.ExtensionCollection extensionCollection) { this.approxFunc = Objects.requireNonNull(extensionCollection.getAggregateFunction(APPROX_COUNT_DISTINCT)); } @Override - public Optional visit(Aggregate aggregate, EmptyVisitationContext context) { + public Optional visit(final Aggregate aggregate, final EmptyVisitationContext context) { return CopyOnWriteUtils .transformList( aggregate.getMeasures(), diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelCreator.java b/isthmus/src/test/java/io/substrait/isthmus/RelCreator.java index 74ad56388..5641011d6 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelCreator.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelCreator.java @@ -27,16 +27,16 @@ public class RelCreator { - private RelOptCluster cluster; + private final RelOptCluster cluster; private CatalogReader catalog; public RelCreator() { this(null); } - public RelCreator(CatalogReader catalogReader) { + public RelCreator(final CatalogReader catalogReader) { if (catalogReader == null) { - CalciteSchema schema = CalciteSchema.createRootSchema(false); + final CalciteSchema schema = CalciteSchema.createRootSchema(false); catalog = new CalciteCatalogReader( schema, @@ -47,30 +47,31 @@ public RelCreator(CatalogReader catalogReader) { catalog = catalogReader; } - VolcanoPlanner planner = new VolcanoPlanner(RelOptCostImpl.FACTORY, Contexts.EMPTY_CONTEXT); + final VolcanoPlanner planner = + new VolcanoPlanner(RelOptCostImpl.FACTORY, Contexts.EMPTY_CONTEXT); cluster = RelOptCluster.create(planner, new RexBuilder(SubstraitTypeSystem.TYPE_FACTORY)); } - public RelRoot parse(String sql) { + public RelRoot parse(final String sql) { try { - SqlParser parser = SqlParser.create(sql, SqlParser.Config.DEFAULT); - SqlNode parsed = parser.parseQuery(); + final SqlParser parser = SqlParser.create(sql, SqlParser.Config.DEFAULT); + final SqlNode parsed = parser.parseQuery(); cluster.setMetadataQuerySupplier( () -> new RelMetadataQuery( new ProxyingMetadataHandlerProvider(DefaultRelMetadataProvider.INSTANCE))); - SqlValidator validator = + final SqlValidator validator = new Validator(catalog, cluster.getTypeFactory(), SqlValidator.Config.DEFAULT); - SqlToRelConverter.Config converterConfig = + final SqlToRelConverter.Config converterConfig = SqlToRelConverter.config().withTrimUnusedFields(true).withExpand(false); - SqlToRelConverter converter = + final SqlToRelConverter converter = new SqlToRelConverter( null, validator, catalog, cluster, StandardConvertletTable.INSTANCE, converterConfig); - RelRoot root = converter.convertQuery(parsed, true, true); + final RelRoot root = converter.convertQuery(parsed, true, true); return root; - } catch (SqlParseException e) { + } catch (final SqlParseException e) { throw new IllegalArgumentException(e); } } @@ -90,7 +91,9 @@ public RelDataTypeFactory typeFactory() { private static final class Validator extends SqlValidatorImpl { public Validator( - SqlValidatorCatalogReader catalogReader, RelDataTypeFactory typeFactory, Config config) { + final SqlValidatorCatalogReader catalogReader, + final RelDataTypeFactory typeFactory, + final Config config) { super(SqlStdOperatorTable.instance(), catalogReader, typeFactory, config); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java index fc1c4d812..40ac66a6f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java @@ -44,46 +44,46 @@ class RelExtensionRoundtripTest extends PlanTestBase { @Test void extensionLeafRelDetailTest() { - ColumnAppendDetail detail = new ColumnAppendDetail(substraitBuilder.i32(1)); - ImmutableExtensionLeaf rel = ExtensionLeaf.from(detail).build(); + final ColumnAppendDetail detail = new ColumnAppendDetail(substraitBuilder.i32(1)); + final ImmutableExtensionLeaf rel = ExtensionLeaf.from(detail).build(); roundtrip(rel); } @Test void extensionSingleRelDetailTest() { - ColumnAppendDetail detail = new ColumnAppendDetail(substraitBuilder.i32(2)); - ImmutableExtensionSingle rel = + final ColumnAppendDetail detail = new ColumnAppendDetail(substraitBuilder.i32(2)); + final ImmutableExtensionSingle rel = ExtensionSingle.from(detail, substraitBuilder.emptyScan()).build(); roundtrip(rel); } @Test void extensionMultiRelDetailTest() { - ColumnAppendDetail detail = new ColumnAppendDetail(substraitBuilder.i32(3)); - ImmutableExtensionMulti rel = + final ColumnAppendDetail detail = new ColumnAppendDetail(substraitBuilder.i32(3)); + final ImmutableExtensionMulti rel = ExtensionMulti.from(detail, substraitBuilder.emptyScan(), substraitBuilder.emptyScan()) .build(); roundtrip(rel); } - void roundtrip(Rel pojo1) { + void roundtrip(final Rel pojo1) { // Substrait POJO 1 -> Substrait Proto - io.substrait.proto.Rel proto = + final io.substrait.proto.Rel proto = pojo1.accept( new RelProtoConverter(new ExtensionCollector()), EmptyVisitationContext.INSTANCE); // Substrait Proto -> Substrait POJO 2 - Rel pojo2 = (new CustomProtoRelConverter(new ExtensionCollector())).from(proto); + final Rel pojo2 = (new CustomProtoRelConverter(new ExtensionCollector())).from(proto); assertEquals(pojo1, pojo2); // Substrait POJO 2 -> Calcite - RelNode calcite = + final RelNode calcite = pojo2.accept( new CustomSubstraitRelNodeConverter(extensions, typeFactory, builder), Context.newContext()); // Calcite -> Substrait POJO 3 - Rel pojo3 = (new CustomSubstraitRelVisitor(typeFactory, extensions)).apply(calcite); + final Rel pojo3 = (new CustomSubstraitRelVisitor(typeFactory, extensions)).apply(calcite); assertEquals(pojo1, pojo3); } @@ -91,7 +91,7 @@ static class ColumnAppendDetail implements Extension.LeafRelDetail, Extension.SingleRelDetail, Extension.MultiRelDetail { Expression.Literal literal; - ColumnAppendDetail(Expression.Literal literal) { + ColumnAppendDetail(final Expression.Literal literal) { this.literal = literal; } @@ -103,7 +103,7 @@ public Type.Struct deriveRecordType() { @Override // SingleRelDetail - public Type.Struct deriveRecordType(Rel input) { + public Type.Struct deriveRecordType(final Rel input) { return Type.Struct.builder() .nullable(false) .addAllFields(input.getRecordType().fields()) @@ -113,20 +113,20 @@ public Type.Struct deriveRecordType(Rel input) { @Override // MultiRelDetail - public Type.Struct deriveRecordType(List inputs) { - ImmutableType.Struct.Builder builder = Type.Struct.builder().nullable(false); - for (Rel input : inputs) { + public Type.Struct deriveRecordType(final List inputs) { + final ImmutableType.Struct.Builder builder = Type.Struct.builder().nullable(false); + for (final Rel input : inputs) { builder.addAllFields(input.getRecordType().fields()); } return builder.addFields(literal.getType()).build(); } @Override - public Any toProto(RelProtoConverter converter) { + public Any toProto(final RelProtoConverter converter) { // the conversion of the literal in the detail requires the presence of the RelProtoConverter - io.substrait.proto.Expression lit = + final io.substrait.proto.Expression lit = converter.getExpressionProtoConverter().toProto(this.literal); - io.substrait.isthmus.extensions.test.protobuf.ColumnAppendDetail inner = + final io.substrait.isthmus.extensions.test.protobuf.ColumnAppendDetail inner = io.substrait.isthmus.extensions.test.protobuf.ColumnAppendDetail.newBuilder() .setLiteral(lit.getLiteral()) .build(); @@ -134,9 +134,9 @@ public Any toProto(RelProtoConverter converter) { } @Override - public boolean equals(Object o) { + public boolean equals(final Object o) { if (o == null || getClass() != o.getClass()) return false; - ColumnAppendDetail that = (ColumnAppendDetail) o; + final ColumnAppendDetail that = (ColumnAppendDetail) o; return Objects.equals(literal, that.literal); } @@ -157,15 +157,15 @@ public String toString() { */ static class CustomProtoRelConverter extends ProtoRelConverter { - public CustomProtoRelConverter(ExtensionLookup lookup) { + public CustomProtoRelConverter(final ExtensionLookup lookup) { super(lookup); } - ColumnAppendDetail unpack(Any any) { + ColumnAppendDetail unpack(final Any any) { try { - io.substrait.isthmus.extensions.test.protobuf.ColumnAppendDetail proto = + final io.substrait.isthmus.extensions.test.protobuf.ColumnAppendDetail proto = any.unpack(io.substrait.isthmus.extensions.test.protobuf.ColumnAppendDetail.class); - Literal literal = + final Literal literal = (new ProtoExpressionConverter( lookup, extensions, Type.Struct.builder().nullable(false).build(), this) .from(proto.getLiteral())); @@ -176,17 +176,17 @@ ColumnAppendDetail unpack(Any any) { } @Override - protected Extension.LeafRelDetail detailFromExtensionLeafRel(Any any) { + protected Extension.LeafRelDetail detailFromExtensionLeafRel(final Any any) { return unpack(any); } @Override - protected Extension.SingleRelDetail detailFromExtensionSingleRel(Any any) { + protected Extension.SingleRelDetail detailFromExtensionSingleRel(final Any any) { return unpack(any); } @Override - protected Extension.MultiRelDetail detailFromExtensionMultiRel(Any any) { + protected Extension.MultiRelDetail detailFromExtensionMultiRel(final Any any) { return unpack(any); } } @@ -198,19 +198,20 @@ protected Extension.MultiRelDetail detailFromExtensionMultiRel(Any any) { static class CustomSubstraitRelNodeConverter extends SubstraitRelNodeConverter { public CustomSubstraitRelNodeConverter( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - RelBuilder relBuilder) { + final SimpleExtension.ExtensionCollection extensions, + final RelDataTypeFactory typeFactory, + final RelBuilder relBuilder) { super(extensions, typeFactory, relBuilder); } @Override - public RelNode visit(ExtensionLeaf extensionLeaf, Context context) { + public RelNode visit(final ExtensionLeaf extensionLeaf, final Context context) { if (extensionLeaf.getDetail() instanceof ColumnAppendDetail) { - ColumnAppendDetail cad = (ColumnAppendDetail) extensionLeaf.getDetail(); - RexLiteral literal = (RexLiteral) cad.literal.accept(this.expressionRexConverter, context); - RelOptCluster cluster = relBuilder.getCluster(); - RelTraitSet traits = cluster.traitSet(); + final ColumnAppendDetail cad = (ColumnAppendDetail) extensionLeaf.getDetail(); + final RexLiteral literal = + (RexLiteral) cad.literal.accept(this.expressionRexConverter, context); + final RelOptCluster cluster = relBuilder.getCluster(); + final RelTraitSet traits = cluster.traitSet(); return new ColumnAppenderRel( relBuilder.getCluster(), traits, literal, Collections.emptyList()); } @@ -218,11 +219,13 @@ public RelNode visit(ExtensionLeaf extensionLeaf, Context context) { } @Override - public RelNode visit(ExtensionSingle extensionSingle, Context context) throws RuntimeException { + public RelNode visit(final ExtensionSingle extensionSingle, final Context context) + throws RuntimeException { if (extensionSingle.getDetail() instanceof ColumnAppendDetail) { - ColumnAppendDetail cad = (ColumnAppendDetail) extensionSingle.getDetail(); - RelNode input = extensionSingle.getInput().accept(this, context); - RexLiteral literal = (RexLiteral) cad.literal.accept(this.expressionRexConverter, context); + final ColumnAppendDetail cad = (ColumnAppendDetail) extensionSingle.getDetail(); + final RelNode input = extensionSingle.getInput().accept(this, context); + final RexLiteral literal = + (RexLiteral) cad.literal.accept(this.expressionRexConverter, context); return new ColumnAppenderRel( input.getCluster(), input.getTraitSet(), literal, List.of(input)); } @@ -230,14 +233,16 @@ public RelNode visit(ExtensionSingle extensionSingle, Context context) throws Ru } @Override - public RelNode visit(ExtensionMulti extensionMulti, Context context) throws RuntimeException { + public RelNode visit(final ExtensionMulti extensionMulti, final Context context) + throws RuntimeException { if (extensionMulti.getDetail() instanceof ColumnAppendDetail) { - ColumnAppendDetail cad = (ColumnAppendDetail) extensionMulti.getDetail(); - List inputs = + final ColumnAppendDetail cad = (ColumnAppendDetail) extensionMulti.getDetail(); + final List inputs = extensionMulti.getInputs().stream() .map(input -> input.accept(this, context)) .collect(Collectors.toList()); - RexLiteral literal = (RexLiteral) cad.literal.accept(this.expressionRexConverter, context); + final RexLiteral literal = + (RexLiteral) cad.literal.accept(this.expressionRexConverter, context); return new ColumnAppenderRel( inputs.get(0).getCluster(), inputs.get(0).getTraitSet(), literal, inputs); } @@ -249,17 +254,18 @@ public RelNode visit(ExtensionMulti extensionMulti, Context context) throws Runt static class CustomSubstraitRelVisitor extends SubstraitRelVisitor { public CustomSubstraitRelVisitor( - RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) { + final RelDataTypeFactory typeFactory, + final SimpleExtension.ExtensionCollection extensions) { super(typeFactory, extensions); } @Override - public Rel visitOther(RelNode other) { + public Rel visitOther(final RelNode other) { if (other instanceof ColumnAppenderRel) { - ColumnAppenderRel car = (ColumnAppenderRel) other; - Expression.Literal literal = (Expression.Literal) toExpression(car.literal); - ColumnAppendDetail detail = new ColumnAppendDetail(literal); - List inputs = apply(car.getInputs()); + final ColumnAppenderRel car = (ColumnAppenderRel) other; + final Expression.Literal literal = (Expression.Literal) toExpression(car.literal); + final ColumnAppendDetail detail = new ColumnAppendDetail(literal); + final List inputs = apply(car.getInputs()); if (inputs.isEmpty()) { return ExtensionLeaf.from(detail).build(); @@ -280,7 +286,10 @@ static class ColumnAppenderRel extends AbstractRelNode { final List inputs; public ColumnAppenderRel( - RelOptCluster cluster, RelTraitSet traitSet, RexLiteral literal, List inputs) { + final RelOptCluster cluster, + final RelTraitSet traitSet, + final RexLiteral literal, + final List inputs) { super(cluster, traitSet); this.literal = literal; this.inputs = inputs; @@ -293,11 +302,11 @@ public List getInputs() { @Override protected RelDataType deriveRowType() { - List fields = new ArrayList<>(); - for (RelNode input : getInputs()) { + final List fields = new ArrayList<>(); + for (final RelNode input : getInputs()) { fields.addAll(input.getRowType().getFieldList()); } - RelDataTypeFieldImpl appendedField = + final RelDataTypeFieldImpl appendedField = new RelDataTypeFieldImpl("appended_column", fields.size(), literal.getType()); fields.add(appendedField); return getCluster() diff --git a/isthmus/src/test/java/io/substrait/isthmus/RoundingFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/RoundingFunctionTest.java index 3d4fdda9b..1071264b2 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RoundingFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RoundingFunctionTest.java @@ -10,22 +10,22 @@ class RoundingFunctionTest extends PlanTestBase { @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void ceil(String column) throws Exception { - String query = String.format("SELECT ceil(%s) FROM numbers", column); + void ceil(final String column) throws Exception { + final String query = String.format("SELECT ceil(%s) FROM numbers", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"fp32", "fp64"}) - void floor(String column) throws Exception { - String query = String.format("SELECT floor(%s) FROM numbers", column); + void floor(final String column) throws Exception { + final String query = String.format("SELECT floor(%s) FROM numbers", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"i8", "i16", "i32", "i64", "fp32", "fp64"}) - void round(String column) throws Exception { - String query = String.format("SELECT round(%s, 2) FROM numbers", column); + void round(final String column) throws Exception { + final String query = String.format("SELECT round(%s, 2) FROM numbers", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java b/isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java index 2869ad797..e5889f81c 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SchemaCollectorTest.java @@ -24,22 +24,22 @@ class SchemaCollectorTest extends PlanTestBase { SubstraitBuilder b = substraitBuilder; SchemaCollector schemaCollector = new SchemaCollector(typeFactory, TypeConverter.DEFAULT); - void hasTable(CalciteSchema schema, String tableName, String tableSchema) { - CalciteSchema.TableEntry table = schema.getTable(tableName, false); + void hasTable(final CalciteSchema schema, final String tableName, final String tableSchema) { + final CalciteSchema.TableEntry table = schema.getTable(tableName, false); assertNotNull(table); assertEquals(tableSchema, table.getTable().getRowType(typeFactory).getFullTypeString()); } @Test void canCollectTables() { - Rel rel = + final Rel rel = b.cross( b.namedScan( List.of("table1"), List.of("col1", "col2", "col3"), List.of(N.I64, R.FP64, N.STRING)), b.namedScan(List.of("table2"), List.of("col4", "col5"), List.of(N.BOOLEAN, N.I32))); - CalciteSchema calciteSchema = schemaCollector.toSchema(rel); + final CalciteSchema calciteSchema = schemaCollector.toSchema(rel); hasTable( calciteSchema, @@ -50,7 +50,7 @@ void canCollectTables() { @Test void canCollectTablesInSchemas() { - Rel rel = + final Rel rel = b.namedWrite( List.of("schema3", "table4"), List.of("col1", "col2", "col3", "col4", "col5", "col6"), @@ -68,23 +68,23 @@ void canCollectTablesInSchemas() { List.of("col4", "col5"), List.of(N.BOOLEAN, N.I32))), b.namedScan(List.of("schema2", "table3"), List.of("col6"), List.of(N.I64)))); - CalciteSchema calciteSchema = schemaCollector.toSchema(rel); + final CalciteSchema calciteSchema = schemaCollector.toSchema(rel); - CalciteSchema schema1 = calciteSchema.getSubSchema("schema1", false); + final CalciteSchema schema1 = calciteSchema.getSubSchema("schema1", false); hasTable(schema1, "table1", "RecordType(BIGINT col1, DOUBLE col2, VARCHAR col3) NOT NULL"); hasTable(schema1, "table2", "RecordType(BOOLEAN col4, INTEGER col5) NOT NULL"); - CalciteSchema schema2 = calciteSchema.getSubSchema("schema2", false); + final CalciteSchema schema2 = calciteSchema.getSubSchema("schema2", false); hasTable(schema2, "table3", "RecordType(BIGINT col6) NOT NULL"); - CalciteSchema schema3 = calciteSchema.getSubSchema("schema3", false); + final CalciteSchema schema3 = calciteSchema.getSubSchema("schema3", false); hasTable( schema3, "table4", "RecordType(BIGINT col1, DOUBLE col2, VARCHAR col3, BOOLEAN col4, INTEGER col5, BIGINT col6) NOT NULL"); } - private static Expression.ScalarFunctionInvocation fnAdd(int value) { + private static Expression.ScalarFunctionInvocation fnAdd(final int value) { return DefaultExtensionCatalog.DEFAULT_COLLECTION.scalarFunctions().stream() .filter(s -> s.name().equalsIgnoreCase("add")) .findFirst() @@ -104,15 +104,15 @@ private static Expression.ScalarFunctionInvocation fnAdd(int value) { @Test void testUpdate() { - List transformations = + final List transformations = Arrays.asList( NamedUpdate.TransformExpression.builder() .columnTarget(0) .transformation(fnAdd(1)) .build()); - Expression condition = ExpressionCreator.bool(false, true); + final Expression condition = ExpressionCreator.bool(false, true); - Rel rel = + final Rel rel = b.namedWrite( List.of("schema1", "table2"), List.of("col1"), @@ -122,8 +122,8 @@ void testUpdate() { b.namedUpdate( List.of("schema1", "table1"), List.of("col1"), transformations, condition, true)); - CalciteSchema calciteSchema = schemaCollector.toSchema(rel); - CalciteSchema schema1 = calciteSchema.getSubSchema("schema1", false); + final CalciteSchema calciteSchema = schemaCollector.toSchema(rel); + final CalciteSchema schema1 = calciteSchema.getSubSchema("schema1", false); hasTable(schema1, "table1", "RecordType(BOOLEAN col1)"); hasTable(schema1, "table2", "RecordType(BOOLEAN col1)"); @@ -131,40 +131,40 @@ void testUpdate() { @Test void canHandleMultipleSchemas() { - Rel rel = + final Rel rel = b.cross( b.namedScan( List.of("level1", "level2a", "level3", "t1"), List.of("col1"), List.of(N.I64)), b.namedScan(List.of("level1", "level2b", "t2"), List.of("col2"), List.of(N.I32))); - CalciteSchema rootSchema = schemaCollector.toSchema(rel); - CalciteSchema level1 = rootSchema.getSubSchema("level1", false); + final CalciteSchema rootSchema = schemaCollector.toSchema(rel); + final CalciteSchema level1 = rootSchema.getSubSchema("level1", false); - CalciteSchema level2a = level1.getSubSchema("level2a", false); - CalciteSchema level3 = level2a.getSubSchema("level3", false); + final CalciteSchema level2a = level1.getSubSchema("level2a", false); + final CalciteSchema level3 = level2a.getSubSchema("level3", false); hasTable(level3, "t1", "RecordType(BIGINT col1) NOT NULL"); - CalciteSchema level2b = level1.getSubSchema("level2b", false); + final CalciteSchema level2b = level1.getSubSchema("level2b", false); hasTable(level2b, "t2", "RecordType(INTEGER col2) NOT NULL"); } @Test void canHandleDuplicateNamedScans() { - Rel table = b.namedScan(List.of("table"), List.of("col1"), List.of(N.BOOLEAN)); - Rel rel = b.cross(table, table); + final Rel table = b.namedScan(List.of("table"), List.of("col1"), List.of(N.BOOLEAN)); + final Rel rel = b.cross(table, table); - CalciteSchema calciteSchema = schemaCollector.toSchema(rel); + final CalciteSchema calciteSchema = schemaCollector.toSchema(rel); hasTable(calciteSchema, "table", "RecordType(BOOLEAN col1) NOT NULL"); } @Test void validatesSchemasForDuplicateNamedScans() { - Rel rel = + final Rel rel = b.cross( b.namedScan(List.of("t"), List.of("col1"), List.of(N.BOOLEAN)), b.namedScan(List.of("t"), List.of("col1"), List.of(R.BOOLEAN))); - IllegalArgumentException exception = + final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> schemaCollector.toSchema(rel)); assertEquals( "NamedScan for [t] is present multiple times with different schemas", @@ -173,12 +173,12 @@ void validatesSchemasForDuplicateNamedScans() { @Test void validatesSchemasForNestedDuplicateNamedScans() { - Rel rel = + final Rel rel = b.cross( b.namedScan(List.of("s", "t"), List.of("col1"), List.of(N.BOOLEAN)), b.namedScan(List.of("s", "t"), List.of("col1"), List.of(R.BOOLEAN))); - IllegalArgumentException exception = + final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> schemaCollector.toSchema(rel)); assertEquals( "NamedScan for [s, t] is present multiple times with different schemas", diff --git a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java index 90c20b93e..a3516a38e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SimpleExtendedExpressionsTest.java @@ -26,15 +26,15 @@ private static Stream expressionTypeProvider() { @ParameterizedTest @MethodSource("expressionTypeProvider") - void testExtendedExpressionsRoundTrip(String sqlExpression) + void testExtendedExpressionsRoundTrip(final String sqlExpression) throws SqlParseException, IOException { assertProtoExtendedExpressionRoundtrip(sqlExpression); } @ParameterizedTest @MethodSource("expressionTypeProvider") - void testExtendedExpressionsDuplicateColumnIdentifierRoundTrip(String sqlExpression) { - IllegalArgumentException illegalArgumentException = + void testExtendedExpressionsDuplicateColumnIdentifierRoundTrip(final String sqlExpression) { + final IllegalArgumentException illegalArgumentException = assertThrows( IllegalArgumentException.class, () -> assertProtoExtendedExpressionRoundtrip(sqlExpression, "tpch/schema_error.sql")); @@ -46,7 +46,7 @@ void testExtendedExpressionsDuplicateColumnIdentifierRoundTrip(String sqlExpress @Test void testExtendedExpressionsListExpressionRoundTrip() throws SqlParseException, IOException { - String[] expressions = { + final String[] expressions = { "2", "L_ORDERKEY", "L_ORDERKEY > 10", diff --git a/isthmus/src/test/java/io/substrait/isthmus/StringFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/StringFunctionTest.java index 6bd56da25..a87f642ce 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/StringFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/StringFunctionTest.java @@ -18,134 +18,134 @@ final class StringFunctionTest extends PlanTestBase { @ParameterizedTest @ValueSource(strings = {"c16", "vc32"}) - void charLength(String column) throws Exception { - String query = String.format("SELECT char_length(%s) FROM strings", column); + void charLength(final String column) throws Exception { + final String query = String.format("SELECT char_length(%s) FROM strings", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"vc32"}) - void concat(String column) throws Exception { - String query = String.format("SELECT %s || %s FROM strings", column, column); + void concat(final String column) throws Exception { + final String query = String.format("SELECT %s || %s FROM strings", column, column); assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32"}) - void lower(String column) throws Exception { - String query = String.format("SELECT lower(%s) FROM strings", column); + void lower(final String column) throws Exception { + final String query = String.format("SELECT lower(%s) FROM strings", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32"}) - void upper(String column) throws Exception { - String query = String.format("SELECT upper(%s) FROM strings", column); + void upper(final String column) throws Exception { + final String query = String.format("SELECT upper(%s) FROM strings", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32"}) - void replace(String column) throws Exception { - String query = + void replace(final String column) throws Exception { + final String query = String.format("SELECT replace(%s, replace_from, replace_to) FROM replace_strings", column); assertSqlSubstraitRelRoundTrip(query, REPLACE_CREATES); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32"}) - void substringWith1Param(String column) throws Exception { - String query = String.format("SELECT substring(%s, 42) FROM strings", column); + void substringWith1Param(final String column) throws Exception { + final String query = String.format("SELECT substring(%s, 42) FROM strings", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32"}) - void substringWith2Params(String column) throws Exception { - String query = String.format("SELECT substring(%s, 42, 5) FROM strings", column); + void substringWith2Params(final String column) throws Exception { + final String query = String.format("SELECT substring(%s, 42, 5) FROM strings", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32"}) - void substringFrom(String column) throws Exception { - String query = String.format("SELECT substring(%s FROM 42) FROM strings", column); + void substringFrom(final String column) throws Exception { + final String query = String.format("SELECT substring(%s FROM 42) FROM strings", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32"}) - void substringFromFor(String column) throws Exception { - String query = String.format("SELECT substring(%s FROM 42 FOR 5) FROM strings", column); + void substringFromFor(final String column) throws Exception { + final String query = String.format("SELECT substring(%s FROM 42 FOR 5) FROM strings", column); assertSqlSubstraitRelRoundTrip(query, CREATES); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32", "vc"}) - void trim(String column) throws Exception { - String query = String.format("SELECT TRIM(%s) FROM strings", column); + void trim(final String column) throws Exception { + final String query = String.format("SELECT TRIM(%s) FROM strings", column); assertSqlRoundTrip(query); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32", "vc"}) - void trimSpecifiedCharacter(String column) throws Exception { - String query = String.format("SELECT TRIM(' ' FROM %s) FROM strings", column); + void trimSpecifiedCharacter(final String column) throws Exception { + final String query = String.format("SELECT TRIM(' ' FROM %s) FROM strings", column); assertSqlRoundTrip(query); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32", "vc"}) - void trimBoth(String column) throws Exception { - String query = String.format("SELECT TRIM(BOTH FROM %s) FROM strings", column); + void trimBoth(final String column) throws Exception { + final String query = String.format("SELECT TRIM(BOTH FROM %s) FROM strings", column); assertSqlRoundTrip(query); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32", "vc"}) - void trimBothSpecifiedCharacter(String column) throws Exception { - String query = String.format("SELECT TRIM(BOTH ' ' FROM %s) FROM strings", column); + void trimBothSpecifiedCharacter(final String column) throws Exception { + final String query = String.format("SELECT TRIM(BOTH ' ' FROM %s) FROM strings", column); assertSqlRoundTrip(query); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32", "vc"}) - void trimLeading(String column) throws Exception { - String query = String.format("SELECT TRIM(LEADING FROM %s) FROM strings", column); + void trimLeading(final String column) throws Exception { + final String query = String.format("SELECT TRIM(LEADING FROM %s) FROM strings", column); assertSqlRoundTrip(query); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32", "vc"}) - void trimLeadingSpecifiedCharacter(String column) throws Exception { - String query = String.format("SELECT TRIM(LEADING ' ' FROM %s) FROM strings", column); + void trimLeadingSpecifiedCharacter(final String column) throws Exception { + final String query = String.format("SELECT TRIM(LEADING ' ' FROM %s) FROM strings", column); assertSqlRoundTrip(query); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32", "vc"}) - void trimTrailing(String column) throws Exception { - String query = String.format("SELECT TRIM(TRAILING FROM %s) FROM strings", column); + void trimTrailing(final String column) throws Exception { + final String query = String.format("SELECT TRIM(TRAILING FROM %s) FROM strings", column); assertSqlRoundTrip(query); } @ParameterizedTest @ValueSource(strings = {"c16", "vc32", "vc"}) - void trimTrailingSpecifiedCharacter(String column) throws Exception { - String query = String.format("SELECT TRIM(TRAILING ' ' FROM %s) FROM strings", column); + void trimTrailingSpecifiedCharacter(final String column) throws Exception { + final String query = String.format("SELECT TRIM(TRAILING ' ' FROM %s) FROM strings", column); assertSqlRoundTrip(query); } - private void assertSqlRoundTrip(String sql) throws SqlParseException { - Plan plan = assertProtoPlanRoundrip(sql, new SqlToSubstrait(), CREATES); + private void assertSqlRoundTrip(final String sql) throws SqlParseException { + final Plan plan = assertProtoPlanRoundrip(sql, new SqlToSubstrait(), CREATES); assertDoesNotThrow(() -> toSql(plan), "Substrait plan to SQL"); } @ParameterizedTest @CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"}) - void testStarts_With(String left, String right) throws Exception { + void testStarts_With(final String left, final String right) throws Exception { - String query = String.format("SELECT STARTS_WITH(%s, %s) FROM strings", left, right); + final String query = String.format("SELECT STARTS_WITH(%s, %s) FROM strings", left, right); assertSqlRoundTrip(query); } @@ -154,16 +154,16 @@ void testStarts_With(String left, String right) throws Exception { @CsvSource( value = {"'start', vc", "vc, 'end'"}, quoteCharacter = '`') - void testStarts_WithLiteral(String left, String right) throws Exception { - String query = String.format("SELECT STARTS_WITH(%s, %s) FROM strings", left, right); + void testStarts_WithLiteral(final String left, final String right) throws Exception { + final String query = String.format("SELECT STARTS_WITH(%s, %s) FROM strings", left, right); assertSqlRoundTrip(query); } @ParameterizedTest @CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"}) - void testStartsWith(String left, String right) throws Exception { + void testStartsWith(final String left, final String right) throws Exception { - String query = String.format("SELECT STARTSWITH(%s, %s) FROM strings", left, right); + final String query = String.format("SELECT STARTSWITH(%s, %s) FROM strings", left, right); assertSqlRoundTrip(query); } @@ -172,16 +172,16 @@ void testStartsWith(String left, String right) throws Exception { @CsvSource( value = {"'start', vc", "vc, 'end'"}, quoteCharacter = '`') - void testStartsWithLiteral(String left, String right) throws Exception { - String query = String.format("SELECT STARTSWITH(%s, %s) FROM strings", left, right); + void testStartsWithLiteral(final String left, final String right) throws Exception { + final String query = String.format("SELECT STARTSWITH(%s, %s) FROM strings", left, right); assertSqlRoundTrip(query); } @ParameterizedTest @CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"}) - void testEnds_With(String left, String right) throws Exception { + void testEnds_With(final String left, final String right) throws Exception { - String query = String.format("SELECT ENDS_WITH(%s, %s) FROM strings", left, right); + final String query = String.format("SELECT ENDS_WITH(%s, %s) FROM strings", left, right); assertSqlRoundTrip(query); } @@ -190,16 +190,16 @@ void testEnds_With(String left, String right) throws Exception { @CsvSource( value = {"'start', vc", "vc, 'end'"}, quoteCharacter = '`') - void testEnds_WithLiteral(String left, String right) throws Exception { - String query = String.format("SELECT ENDS_WITH(%s, %s) FROM strings", left, right); + void testEnds_WithLiteral(final String left, final String right) throws Exception { + final String query = String.format("SELECT ENDS_WITH(%s, %s) FROM strings", left, right); assertSqlRoundTrip(query); } @ParameterizedTest @CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"}) - void testEndsWith(String left, String right) throws Exception { + void testEndsWith(final String left, final String right) throws Exception { - String query = String.format("SELECT ENDSWITH(%s, %s) FROM strings", left, right); + final String query = String.format("SELECT ENDSWITH(%s, %s) FROM strings", left, right); assertSqlRoundTrip(query); } @@ -208,16 +208,16 @@ void testEndsWith(String left, String right) throws Exception { @CsvSource( value = {"'start', vc", "vc, 'end'"}, quoteCharacter = '`') - void testEndsWithLiteral(String left, String right) throws Exception { - String query = String.format("SELECT ENDSWITH(%s, %s) FROM strings", left, right); + void testEndsWithLiteral(final String left, final String right) throws Exception { + final String query = String.format("SELECT ENDSWITH(%s, %s) FROM strings", left, right); assertSqlRoundTrip(query); } @ParameterizedTest @CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"}) - void testContains(String left, String right) throws Exception { + void testContains(final String left, final String right) throws Exception { - String query = String.format("SELECT CONTAINS_SUBSTR(%s, %s) FROM strings", left, right); + final String query = String.format("SELECT CONTAINS_SUBSTR(%s, %s) FROM strings", left, right); assertSqlRoundTrip(query); } @@ -226,18 +226,18 @@ void testContains(String left, String right) throws Exception { @CsvSource( value = {"'start', vc", "vc, 'end'"}, quoteCharacter = '`') - void testContainsWithLiteral(String left, String right) throws Exception { + void testContainsWithLiteral(final String left, final String right) throws Exception { - String query = String.format("SELECT CONTAINS_SUBSTR(%s, %s) FROM strings", left, right); + final String query = String.format("SELECT CONTAINS_SUBSTR(%s, %s) FROM strings", left, right); assertSqlRoundTrip(query); } @ParameterizedTest @CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"}) - void testPosition(String left, String right) throws Exception { + void testPosition(final String left, final String right) throws Exception { - String query = String.format("SELECT POSITION(%s IN %s) > 0 FROM strings", left, right); + final String query = String.format("SELECT POSITION(%s IN %s) > 0 FROM strings", left, right); assertSqlRoundTrip(query); } @@ -246,18 +246,18 @@ void testPosition(String left, String right) throws Exception { @CsvSource( value = {"'start', vc", "vc, 'end'"}, quoteCharacter = '`') - void testPositionWithLiteral(String left, String right) throws Exception { + void testPositionWithLiteral(final String left, final String right) throws Exception { - String query = String.format("SELECT POSITION(%s IN %s) > 0 FROM strings", left, right); + final String query = String.format("SELECT POSITION(%s IN %s) > 0 FROM strings", left, right); assertSqlRoundTrip(query); } @ParameterizedTest @CsvSource({"c16, c16", "c16, vc32", "c16, vc", "vc32, vc32", "vc32, vc", "vc, vc"}) - void testStrpos(String left, String right) throws Exception { + void testStrpos(final String left, final String right) throws Exception { - String query = String.format("SELECT STRPOS(%s, %s) > 0 FROM strings", left, right); + final String query = String.format("SELECT STRPOS(%s, %s) > 0 FROM strings", left, right); assertSqlRoundTrip(query); } @@ -266,36 +266,36 @@ void testStrpos(String left, String right) throws Exception { @CsvSource( value = {"'start', vc", "vc, 'end'"}, quoteCharacter = '`') - void testStrposWithLiteral(String left, String right) throws Exception { + void testStrposWithLiteral(final String left, final String right) throws Exception { - String query = String.format("SELECT STRPOS(%s, %s) > 0 FROM strings", left, right); + final String query = String.format("SELECT STRPOS(%s, %s) > 0 FROM strings", left, right); assertSqlRoundTrip(query); } @ParameterizedTest @CsvSource({"vc32, i32", "vc, i32"}) - void testLeft(String left, String right) throws Exception { + void testLeft(final String left, final String right) throws Exception { - String query = String.format("SELECT LEFT(%s, %s) FROM int_num_strings", left, right); + final String query = String.format("SELECT LEFT(%s, %s) FROM int_num_strings", left, right); assertSqlSubstraitRelRoundTrip(query, CHAR_INT_CREATES); } @ParameterizedTest @CsvSource({"vc32, i32", "vc, i32"}) - void testRight(String left, String right) throws Exception { + void testRight(final String left, final String right) throws Exception { - String query = String.format("SELECT RIGHT(%s, %s) FROM int_num_strings", left, right); + final String query = String.format("SELECT RIGHT(%s, %s) FROM int_num_strings", left, right); assertSqlSubstraitRelRoundTrip(query, CHAR_INT_CREATES); } @ParameterizedTest @CsvSource({"vc32, i32, vc32", "vc, i32, vc"}) - void testRpad(String left, String center, String right) throws Exception { + void testRpad(final String left, final String center, final String right) throws Exception { - String query = + final String query = String.format("SELECT RPAD(%s, %s, %s) FROM int_num_strings", left, center, right); assertSqlSubstraitRelRoundTrip(query, CHAR_INT_CREATES); @@ -303,9 +303,9 @@ void testRpad(String left, String center, String right) throws Exception { @ParameterizedTest @CsvSource({"vc32, i32, vc32", "vc, i32, vc"}) - void testLpad(String left, String center, String right) throws Exception { + void testLpad(final String left, final String center, final String right) throws Exception { - String query = + final String query = String.format("SELECT LPAD(%s, %s, %s) FROM int_num_strings", left, center, right); assertSqlSubstraitRelRoundTrip(query, CHAR_INT_CREATES); diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubqueryPlanTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubqueryPlanTest.java index 9e9aad657..99b07bd09 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubqueryPlanTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubqueryPlanTest.java @@ -19,14 +19,14 @@ class SubqueryPlanTest extends PlanTestBase { @Test void existsCorrelatedSubquery() throws SqlParseException { - SqlToSubstrait s = new SqlToSubstrait(); - Plan plan = + final SqlToSubstrait s = new SqlToSubstrait(); + final Plan plan = toProto( s.convert( "select l_partkey from lineitem where exists (select o_orderdate from orders where o_orderkey = l_orderkey)", TPCH_CATALOG)); - Expression.Subquery subquery = + final Expression.Subquery subquery = plan.getRelations(0) .getRoot() .getInput() @@ -39,13 +39,13 @@ void existsCorrelatedSubquery() throws SqlParseException { assertTrue(subquery.hasSetPredicate()); assertSame(PredicateOp.PREDICATE_OP_EXISTS, subquery.getSetPredicate().getPredicateOp()); - FilterRel setPredicateFilter = + final FilterRel setPredicateFilter = subquery .getSetPredicate() .getTuples() .getFilter(); // exits (select ... from orders where o_orderkey = l_orderkey) - Expression.FieldReference correlatedCol = + final Expression.FieldReference correlatedCol = setPredicateFilter .getCondition() .getScalarFunction() @@ -59,14 +59,14 @@ void existsCorrelatedSubquery() throws SqlParseException { @Test void uniqueCorrelatedSubquery() throws IOException, SqlParseException { - SqlToSubstrait s = new SqlToSubstrait(); - Plan plan = + final SqlToSubstrait s = new SqlToSubstrait(); + final Plan plan = toProto( s.convert( "select l_partkey from lineitem where unique (select o_orderdate from orders where o_orderkey = l_orderkey)", TPCH_CATALOG)); - Expression.Subquery subquery = + final Expression.Subquery subquery = plan.getRelations(0) .getRoot() .getInput() @@ -77,7 +77,7 @@ void uniqueCorrelatedSubquery() throws IOException, SqlParseException { .getSubquery(); assertTrue(subquery.hasSetPredicate()); - FilterRel setPredicateFilter = + final FilterRel setPredicateFilter = subquery .getSetPredicate() .getTuples() @@ -88,7 +88,7 @@ void uniqueCorrelatedSubquery() throws IOException, SqlParseException { assertTrue(subquery.hasSetPredicate()); assertSame(PredicateOp.PREDICATE_OP_UNIQUE, subquery.getSetPredicate().getPredicateOp()); - Expression.FieldReference correlatedCol = + final Expression.FieldReference correlatedCol = setPredicateFilter .getCondition() .getScalarFunction() @@ -102,12 +102,12 @@ void uniqueCorrelatedSubquery() throws IOException, SqlParseException { @Test void inPredicateCorrelatedSubQuery() throws IOException, SqlParseException { - SqlToSubstrait s = new SqlToSubstrait(); - String sql = + final SqlToSubstrait s = new SqlToSubstrait(); + final String sql = "select l_orderkey from lineitem where l_partkey in (select p_partkey from part where p_partkey = l_partkey)"; - Plan plan = toProto(s.convert(sql, TPCH_CATALOG)); + final Plan plan = toProto(s.convert(sql, TPCH_CATALOG)); - Expression.Subquery subquery = + final Expression.Subquery subquery = plan.getRelations(0) .getRoot() .getInput() @@ -118,7 +118,7 @@ void inPredicateCorrelatedSubQuery() throws IOException, SqlParseException { .getSubquery(); assertTrue(subquery.hasInPredicate()); - FilterRel insubqueryFilter = + final FilterRel insubqueryFilter = subquery .getInPredicate() .getHaystack() @@ -126,7 +126,7 @@ void inPredicateCorrelatedSubQuery() throws IOException, SqlParseException { .getInput() .getFilter(); // p_partkey = l_partkey - Expression.FieldReference correlatedCol = + final Expression.FieldReference correlatedCol = insubqueryFilter .getCondition() .getScalarFunction() @@ -140,11 +140,11 @@ void inPredicateCorrelatedSubQuery() throws IOException, SqlParseException { @Test void notInPredicateCorrelatedSubquery() throws IOException, SqlParseException { - SqlToSubstrait s = new SqlToSubstrait(); - String sql = + final SqlToSubstrait s = new SqlToSubstrait(); + final String sql = "select l_orderkey from lineitem where l_partkey not in (select p_partkey from part where p_partkey = l_partkey)"; - Plan plan = toProto(s.convert(sql, TPCH_CATALOG)); - Expression.Subquery subquery = + final Plan plan = toProto(s.convert(sql, TPCH_CATALOG)); + final Expression.Subquery subquery = plan.getRelations(0) .getRoot() .getInput() @@ -158,7 +158,7 @@ void notInPredicateCorrelatedSubquery() throws IOException, SqlParseException { .getSubquery(); assertTrue(subquery.hasInPredicate()); - FilterRel insubqueryFilter = + final FilterRel insubqueryFilter = subquery .getInPredicate() .getHaystack() @@ -166,7 +166,7 @@ void notInPredicateCorrelatedSubquery() throws IOException, SqlParseException { .getInput() .getFilter(); // p_partkey = l_partkey - Expression.FieldReference correlatedCol = + final Expression.FieldReference correlatedCol = insubqueryFilter .getCondition() .getScalarFunction() @@ -180,8 +180,8 @@ void notInPredicateCorrelatedSubquery() throws IOException, SqlParseException { @Test void existsNestedCorrelatedSubquery() throws IOException, SqlParseException { - SqlToSubstrait s = new SqlToSubstrait(); - String sql = + final SqlToSubstrait s = new SqlToSubstrait(); + final String sql = "SELECT p_partkey\n" + "FROM part p\n" + "WHERE EXISTS\n" @@ -193,9 +193,9 @@ void existsNestedCorrelatedSubquery() throws IOException, SqlParseException { + " FROM partsupp ps\n" + " WHERE ps.ps_partkey = p.p_partkey\n" + " AND PS.ps_suppkey = l.l_suppkey))"; - Plan plan = toProto(s.convert(sql, TPCH_CATALOG)); + final Plan plan = toProto(s.convert(sql, TPCH_CATALOG)); - Expression.Subquery outer_subquery = + final Expression.Subquery outer_subquery = plan.getRelations(0) .getRoot() .getInput() @@ -208,19 +208,19 @@ void existsNestedCorrelatedSubquery() throws IOException, SqlParseException { assertTrue(outer_subquery.hasSetPredicate()); assertSame(PredicateOp.PREDICATE_OP_EXISTS, outer_subquery.getSetPredicate().getPredicateOp()); - FilterRel exists_filter = + final FilterRel exists_filter = outer_subquery .getSetPredicate() .getTuples() .getFilter(); // l.l_partkey = p.p_partkey and unique (...) - Expression.Subquery inner_subquery = + final Expression.Subquery inner_subquery = exists_filter.getCondition().getScalarFunction().getArguments(1).getValue().getSubquery(); assertTrue(inner_subquery.hasSetPredicate()); assertSame(PredicateOp.PREDICATE_OP_UNIQUE, inner_subquery.getSetPredicate().getPredicateOp()); - Expression inner_subquery_condition = + final Expression inner_subquery_condition = inner_subquery .getSetPredicate() .getTuples() @@ -229,18 +229,18 @@ void existsNestedCorrelatedSubquery() throws IOException, SqlParseException { .getFilter() .getCondition(); - Expression inner_subquery_cond1 = + final Expression inner_subquery_cond1 = inner_subquery_condition .getScalarFunction() .getArguments(0) .getValue(); // ps.ps_partkey = p.p_partkey - Expression inner_subquery_cond2 = + final Expression inner_subquery_cond2 = inner_subquery_condition .getScalarFunction() .getArguments(1) .getValue(); // PS.ps_suppkey = l.l_suppkey - Expression.FieldReference correlatedCol1 = + final Expression.FieldReference correlatedCol1 = inner_subquery_cond1 .getScalarFunction() .getArguments(1) @@ -249,7 +249,7 @@ void existsNestedCorrelatedSubquery() throws IOException, SqlParseException { assertEquals(0, correlatedCol1.getDirectReference().getStructField().getField()); assertEquals(2, correlatedCol1.getOuterReference().getStepsOut()); - Expression.FieldReference correlatedCol2 = + final Expression.FieldReference correlatedCol2 = inner_subquery_cond2 .getScalarFunction() .getArguments(1) @@ -261,14 +261,14 @@ void existsNestedCorrelatedSubquery() throws IOException, SqlParseException { @Test void nestedScalarCorrelatedSubquery() throws IOException, SqlParseException { - SqlToSubstrait s = new SqlToSubstrait(); - String sql = asString("subquery/nested_scalar_subquery_in_filter.sql"); - Plan plan = toProto(s.convert(sql, TPCH_CATALOG)); - String planText = JsonFormat.printer().includingDefaultValueFields().print(plan); + final SqlToSubstrait s = new SqlToSubstrait(); + final String sql = asString("subquery/nested_scalar_subquery_in_filter.sql"); + final Plan plan = toProto(s.convert(sql, TPCH_CATALOG)); + final String planText = JsonFormat.printer().includingDefaultValueFields().print(plan); System.out.println(planText); - Expression.Subquery outer_subquery = + final Expression.Subquery outer_subquery = plan.getRelations(0) .getRoot() .getInput() @@ -283,7 +283,7 @@ void nestedScalarCorrelatedSubquery() throws IOException, SqlParseException { assertTrue(outer_subquery.hasScalar()); - Expression.Subquery inner_subquery = + final Expression.Subquery inner_subquery = outer_subquery .getScalar() .getInput() @@ -301,21 +301,21 @@ void nestedScalarCorrelatedSubquery() throws IOException, SqlParseException { .getValue() .getSubquery(); - Expression inner_subquery_condition = + final Expression inner_subquery_condition = inner_subquery.getScalar().getInput().getAggregate().getInput().getFilter().getCondition(); - Expression inner_subquery_cond1 = + final Expression inner_subquery_cond1 = inner_subquery_condition .getScalarFunction() .getArguments(0) .getValue(); // ps.ps_partkey = p.p_partkey - Expression inner_subquery_cond2 = + final Expression inner_subquery_cond2 = inner_subquery_condition .getScalarFunction() .getArguments(1) .getValue(); // PS.ps_suppkey = l.l_suppkey - Expression.FieldReference correlatedCol1 = + final Expression.FieldReference correlatedCol1 = inner_subquery_cond1 .getScalarFunction() .getArguments(1) @@ -324,7 +324,7 @@ void nestedScalarCorrelatedSubquery() throws IOException, SqlParseException { assertEquals(0, correlatedCol1.getDirectReference().getStructField().getField()); assertEquals(2, correlatedCol1.getOuterReference().getStepsOut()); - Expression.FieldReference correlatedCol2 = + final Expression.FieldReference correlatedCol2 = inner_subquery_cond2 .getScalarFunction() .getArguments(1) @@ -336,7 +336,7 @@ void nestedScalarCorrelatedSubquery() throws IOException, SqlParseException { @Test void correlatedScalarSubQueryInSelect() throws Exception { - String sql = asString("subquery/nested_scalar_subquery_in_select.sql"); + final String sql = asString("subquery/nested_scalar_subquery_in_select.sql"); assertSqlSubstraitRelRoundTrip(sql); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java b/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java index c67e172bc..86b0bcc28 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/Substrait2SqlTest.java @@ -16,24 +16,24 @@ import org.junit.jupiter.params.provider.MethodSource; class Substrait2SqlTest extends PlanTestBase { - private void assertSqlRoundTripViaPojoAndProto(String inputSql) { - Plan plan = + private void assertSqlRoundTripViaPojoAndProto(final String inputSql) { + final Plan plan = assertDoesNotThrow(() -> toSubstraitPlan(inputSql, TPCH_CATALOG), "SQL to Substrait POJO"); assertDoesNotThrow(() -> toSql(plan), "Substrait POJO to SQL"); - io.substrait.proto.Plan proto = + final io.substrait.proto.Plan proto = assertDoesNotThrow(() -> toProto(plan), "Substrait POJO to Substrait PROTO"); assertDoesNotThrow(() -> toSql(proto), "Substrait PROTO to SQL"); } @Test void simpleTest() throws Exception { - String query = "select p_size from part where p_partkey > cast(100 as bigint)"; + final String query = "select p_size from part where p_partkey > cast(100 as bigint)"; assertSqlSubstraitRelRoundTrip(query); } @Test void simpleTest2() throws Exception { - String query = + final String query = "select l_partkey, l_discount from lineitem where l_orderkey > cast(100 as bigint)"; assertSqlSubstraitRelRoundTrip(query); } @@ -136,7 +136,7 @@ void simpleTestAgg3() throws Exception { @ParameterizedTest @MethodSource("io.substrait.isthmus.utils.SetUtils#setTestConfig") - void setTest(Set.SetOp op, boolean multi) throws Exception { + void setTest(final Set.SetOp op, final boolean multi) throws Exception { assertSqlSubstraitRelRoundTrip(SetUtils.getSetQuery(op, multi)); } @@ -166,13 +166,13 @@ void tpch_q1_variant() throws Exception { @Test void simpleTestApproxCountDistinct() throws Exception { - String query = "select approx_count_distinct(l_tax) from lineitem"; - RelRoot relRoot = assertSqlSubstraitRelRoundTrip(query); - RelNode relNode = relRoot.project(); + final String query = "select approx_count_distinct(l_tax) from lineitem"; + final RelRoot relRoot = assertSqlSubstraitRelRoundTrip(query); + final RelNode relNode = relRoot.project(); // Assert converted Calcite RelNode has `approx_count_distinct` assertInstanceOf(LogicalAggregate.class, relNode); - LogicalAggregate aggregate = (LogicalAggregate) relNode; + final LogicalAggregate aggregate = (LogicalAggregate) relNode; assertEquals( SqlStdOperatorTable.APPROX_COUNT_DISTINCT, aggregate.getAggCallList().get(0).getAggregation()); diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java index d2ffd60ef..964f8b10a 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java @@ -55,90 +55,90 @@ public SubstraitExpressionConverterTest() { @Test void switchExpression() { - Switch expr = + final Switch expr = b.switchExpression( b.fieldReference(commonTable, 0), List.of(b.switchClause(b.i32(0), b.fieldReference(commonTable, 3))), b.bool(false)); - RexNode calciteExpr = expr.accept(converter, Context.newContext()); + final RexNode calciteExpr = expr.accept(converter, Context.newContext()); assertTypeMatch(calciteExpr.getType(), N.BOOLEAN); } @Test void scalarSubQuery() { - Rel subQueryRel = createSubQueryRel(); + final Rel subQueryRel = createSubQueryRel(); - Expression.ScalarSubquery expr = + final Expression.ScalarSubquery expr = Expression.ScalarSubquery.builder().type(R.I64).input(subQueryRel).build(); - Project query = b.project(input -> List.of(expr), b.emptyScan()); + final Project query = b.project(input -> List.of(expr), b.emptyScan()); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); - RelNode calciteRel = substraitToCalcite.convert(query); + final SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); + final RelNode calciteRel = substraitToCalcite.convert(query); assertInstanceOf(LogicalProject.class, calciteRel); - List calciteProjectExpr = ((LogicalProject) calciteRel).getProjects(); + final List calciteProjectExpr = ((LogicalProject) calciteRel).getProjects(); assertEquals(1, calciteProjectExpr.size()); assertEquals(SqlKind.SCALAR_QUERY, calciteProjectExpr.get(0).getKind()); } @Test void existsSetPredicate() { - Rel subQueryRel = createSubQueryRel(); + final Rel subQueryRel = createSubQueryRel(); - Expression.SetPredicate expr = + final Expression.SetPredicate expr = Expression.SetPredicate.builder() .predicateOp(Expression.PredicateOp.PREDICATE_OP_EXISTS) .tuples(subQueryRel) .build(); - Project query = b.project(input -> List.of(expr), b.emptyScan()); + final Project query = b.project(input -> List.of(expr), b.emptyScan()); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); - RelNode calciteRel = substraitToCalcite.convert(query); + final SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); + final RelNode calciteRel = substraitToCalcite.convert(query); assertInstanceOf(LogicalProject.class, calciteRel); - List calciteProjectExpr = ((LogicalProject) calciteRel).getProjects(); + final List calciteProjectExpr = ((LogicalProject) calciteRel).getProjects(); assertEquals(1, calciteProjectExpr.size()); assertEquals(SqlKind.EXISTS, calciteProjectExpr.get(0).getKind()); } @Test void uniqueSetPredicate() { - Rel subQueryRel = createSubQueryRel(); + final Rel subQueryRel = createSubQueryRel(); - Expression.SetPredicate expr = + final Expression.SetPredicate expr = Expression.SetPredicate.builder() .predicateOp(Expression.PredicateOp.PREDICATE_OP_UNIQUE) .tuples(subQueryRel) .build(); - Project query = b.project(input -> List.of(expr), b.emptyScan()); + final Project query = b.project(input -> List.of(expr), b.emptyScan()); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); - RelNode calciteRel = substraitToCalcite.convert(query); + final SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); + final RelNode calciteRel = substraitToCalcite.convert(query); assertInstanceOf(LogicalProject.class, calciteRel); - List calciteProjectExpr = ((LogicalProject) calciteRel).getProjects(); + final List calciteProjectExpr = ((LogicalProject) calciteRel).getProjects(); assertEquals(1, calciteProjectExpr.size()); assertEquals(SqlKind.UNIQUE, calciteProjectExpr.get(0).getKind()); } @Test void unspecifiedSetPredicate() { - Rel subQueryRel = createSubQueryRel(); + final Rel subQueryRel = createSubQueryRel(); - Expression.SetPredicate expr = + final Expression.SetPredicate expr = Expression.SetPredicate.builder() .predicateOp(Expression.PredicateOp.PREDICATE_OP_UNSPECIFIED) .tuples(subQueryRel) .build(); - Project query = b.project(input -> List.of(expr), b.emptyScan()); + final Project query = b.project(input -> List.of(expr), b.emptyScan()); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); - Exception exception = + final SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); + final Exception exception = assertThrows( UnsupportedOperationException.class, () -> { @@ -166,7 +166,7 @@ Rel createSubQueryRel() { @Test void useSubstraitReturnTypeDuringScalarFunctionConversion() { - Expression.ScalarFunctionInvocation expr = + final Expression.ScalarFunctionInvocation expr = b.scalarFn( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "add:i32_i32", @@ -176,13 +176,13 @@ void useSubstraitReturnTypeDuringScalarFunctionConversion() { b.i32(7), b.i32(42)); - RexNode calciteExpr = expr.accept(expressionRexConverter, Context.newContext()); + final RexNode calciteExpr = expr.accept(expressionRexConverter, Context.newContext()); assertEquals(TypeConverter.DEFAULT.toCalcite(typeFactory, R.FP32), calciteExpr.getType()); } @Test void useSubstraitReturnTypeDuringWindowFunctionConversion() { - Expression.WindowFunctionInvocation expr = + final Expression.WindowFunctionInvocation expr = b.windowFn( DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "row_number:", @@ -196,12 +196,12 @@ void useSubstraitReturnTypeDuringWindowFunctionConversion() { WindowBound.UNBOUNDED, b.i32(42)); - RexNode calciteExpr = expr.accept(expressionRexConverter, Context.newContext()); + final RexNode calciteExpr = expr.accept(expressionRexConverter, Context.newContext()); assertEquals(TypeConverter.DEFAULT.toCalcite(typeFactory, R.STRING), calciteExpr.getType()); } - void assertTypeMatch(RelDataType actual, Type expected) { - Type type = TypeConverter.DEFAULT.toSubstrait(actual); + void assertTypeMatch(final RelDataType actual, final Type expected) { + final Type type = TypeConverter.DEFAULT.toSubstrait(actual); assertEquals(expected, type); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java index e9cc9e02a..a38c37075 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java @@ -37,20 +37,20 @@ class SubstraitRelNodeConverterTest extends PlanTestBase { class Aggregate { @Test void direct() { - Plan.Root root = + final Plan.Root root = b.root( b.aggregate( input -> b.grouping(input, 0, 2), input -> List.of(b.count(input, 0)), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING, R.I64); } @Test void emit() { - Plan.Root root = + final Plan.Root root = b.root( b.aggregate( input -> b.grouping(input, 0, 2), @@ -58,7 +58,7 @@ void emit() { b.remap(1, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), N.STRING, R.I64); } } @@ -67,17 +67,17 @@ void emit() { class Cross { @Test void direct() { - Plan.Root root = b.root(b.cross(commonTable, commonTable)); + final Plan.Root root = b.root(b.cross(commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableTypeTwice); } @Test void emit() { - Plan.Root root = b.root(b.cross(commonTable, commonTable, b.remap(0, 1, 4, 6))); + final Plan.Root root = b.root(b.cross(commonTable, commonTable, b.remap(0, 1, 4, 6))); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, R.FP32, R.I32, N.STRING); } } @@ -86,17 +86,17 @@ void emit() { class Fetch { @Test void direct() { - Plan.Root root = b.root(b.fetch(20, 40, commonTable)); + final Plan.Root root = b.root(b.fetch(20, 40, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableType); } @Test void emit() { - Plan.Root root = b.root(b.fetch(20, 40, b.remap(0, 2), commonTable)); + final Plan.Root root = b.root(b.fetch(20, 40, b.remap(0, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } } @@ -105,17 +105,17 @@ void emit() { class Filter { @Test void direct() { - Plan.Root root = b.root(b.filter(input -> b.bool(true), commonTable)); + final Plan.Root root = b.root(b.filter(input -> b.bool(true), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableType); } @Test void emit() { - Plan.Root root = b.root(b.filter(input -> b.bool(true), b.remap(0, 2), commonTable)); + final Plan.Root root = b.root(b.filter(input -> b.bool(true), b.remap(0, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } } @@ -124,18 +124,18 @@ void emit() { class Join { @Test void direct() { - Plan.Root root = b.root(b.innerJoin(input -> b.bool(true), commonTable, commonTable)); + final Plan.Root root = b.root(b.innerJoin(input -> b.bool(true), commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableTypeTwice); } @Test void emit() { - Plan.Root root = + final Plan.Root root = b.root(b.innerJoin(input -> b.bool(true), b.remap(0, 6), commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } @@ -144,14 +144,14 @@ void leftJoin() { final List joinTableType = List.of(R.STRING, R.FP64, R.BINARY); final Rel joinTable = b.namedScan(List.of("join"), List.of("a", "b", "c"), joinTableType); - Plan.Root root = + final Plan.Root root = b.root( b.project( r -> b.fieldReferences(r, 0, 1, 3), b.remap(6, 7, 8), b.join(ji -> b.bool(true), JoinType.LEFT, joinTable, joinTable))); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.STRING, R.FP64, N.STRING); } @@ -160,14 +160,14 @@ void rightJoin() { final List joinTableType = List.of(R.STRING, R.FP64, R.BINARY); final Rel joinTable = b.namedScan(List.of("join"), List.of("a", "b", "c"), joinTableType); - Plan.Root root = + final Plan.Root root = b.root( b.project( r -> b.fieldReferences(r, 0, 1, 3), b.remap(6, 7, 8), b.join(ji -> b.bool(true), JoinType.RIGHT, joinTable, joinTable))); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), N.STRING, N.FP64, R.STRING); } @@ -176,14 +176,14 @@ void outerJoin() { final List joinTableType = List.of(R.STRING, R.FP64, R.BINARY); final Rel joinTable = b.namedScan(List.of("join"), List.of("a", "b", "c"), joinTableType); - Plan.Root root = + final Plan.Root root = b.root( b.project( r -> b.fieldReferences(r, 0, 1, 3), b.remap(6, 7, 8), b.join(ji -> b.bool(true), JoinType.OUTER, joinTable, joinTable))); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), N.STRING, N.FP64, N.STRING); } } @@ -192,21 +192,21 @@ void outerJoin() { class NamedScan { @Test void direct() { - Plan.Root root = + final Plan.Root root = b.root(b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.I32, R.FP32))); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, R.FP32); } @Test void emit() { - Plan.Root root = + final Plan.Root root = b.root( b.namedScan( List.of("example"), List.of("a", "b"), List.of(R.I32, R.FP32), b.remap(1))); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.FP32); } } @@ -215,21 +215,22 @@ void emit() { class Project { @Test void direct() { - Plan.Root root = b.root(b.project(input -> b.fieldReferences(input, 1, 0, 2), commonTable)); + final Plan.Root root = + b.root(b.project(input -> b.fieldReferences(input, 1, 0, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch( relNode.getRowType(), R.I32, R.FP32, N.STRING, N.BOOLEAN, R.FP32, R.I32, N.STRING); } @Test void emit() { - Plan.Root root = + final Plan.Root root = b.root( b.project( input -> b.fieldReferences(input, 1, 0, 2), b.remap(0, 2, 4, 6), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING, R.FP32, N.STRING); } } @@ -238,17 +239,18 @@ void emit() { class Set { @Test void direct() { - Plan.Root root = b.root(b.set(SetOp.UNION_ALL, commonTable, commonTable)); + final Plan.Root root = b.root(b.set(SetOp.UNION_ALL, commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableType); } @Test void emit() { - Plan.Root root = b.root(b.set(SetOp.UNION_ALL, b.remap(0, 2), commonTable, commonTable)); + final Plan.Root root = + b.root(b.set(SetOp.UNION_ALL, b.remap(0, 2), commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } } @@ -257,18 +259,18 @@ void emit() { class Sort { @Test void direct() { - Plan.Root root = b.root(b.sort(input -> b.sortFields(input, 0, 1, 2), commonTable)); + final Plan.Root root = b.root(b.sort(input -> b.sortFields(input, 0, 1, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableType); } @Test void emit() { - Plan.Root root = + final Plan.Root root = b.root(b.sort(input -> b.sortFields(input, 0, 1, 2), b.remap(0, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } } @@ -278,26 +280,26 @@ class EmptyScan { @Test void direct() { - Rel emptyScan = + final Rel emptyScan = io.substrait.relation.EmptyScan.builder() .initialSchema(NamedStruct.of(Collections.emptyList(), R.struct(R.I32, N.STRING))) .build(); - Plan.Root root = b.root(emptyScan); - RelNode relNode = converter.convert(root.getInput()); + final Plan.Root root = b.root(emptyScan); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), List.of(R.I32, N.STRING)); } @Test void emit() { - Rel emptyScanWithRemap = + final Rel emptyScanWithRemap = io.substrait.relation.EmptyScan.builder() .initialSchema(NamedStruct.of(Collections.emptyList(), R.struct(R.I32, N.STRING))) .remap(Rel.Remap.of(List.of(0))) .build(); - Plan.Root root = b.root(emptyScanWithRemap); - RelNode relNode = converter.convert(root.getInput()); + final Plan.Root root = b.root(emptyScanWithRemap); + final RelNode relNode = converter.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java index d5d8ada75..9dde11b0c 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java @@ -17,28 +17,28 @@ class SubstraitToCalciteTest extends PlanTestBase { @Test void testConvertRootSingleColumn() { - Iterable types = List.of(TypeCreator.REQUIRED.STRING); - Root root = + final Iterable types = List.of(TypeCreator.REQUIRED.STRING); + final Root root = Root.builder() .input(substraitBuilder.namedScan(List.of("stores"), List.of("s"), types)) .addNames("store") .build(); - RelRoot relRoot = converter.convert(root); + final RelRoot relRoot = converter.convert(root); assertEquals(root.getNames(), relRoot.fields.rightList()); } @Test void testConvertRootMultipleColumns() { - Iterable types = List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING); - Root root = + final Iterable types = List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING); + final Root root = Root.builder() .input(substraitBuilder.namedScan(List.of("stores"), List.of("s_store_id", "s"), types)) .addNames("s_store_id", "store") .build(); - RelRoot relRoot = converter.convert(root); + final RelRoot relRoot = converter.convert(root); assertEquals(root.getNames(), relRoot.fields.rightList()); } @@ -47,8 +47,8 @@ void testConvertRootMultipleColumns() { void testConvertRootStructField() { final Type structType = TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING); - Iterable types = List.of(structType); - Root root = + final Iterable types = List.of(structType); + final Root root = Root.builder() .input( substraitBuilder.namedScan( @@ -58,7 +58,7 @@ void testConvertRootStructField() { assertEquals(List.of("store", "store_id", "store_name"), root.getNames()); - RelRoot relRoot = converter.convert(root); + final RelRoot relRoot = converter.convert(root); // Apache Calcite's RelRoot.fields only contains the top level field names assertEquals(List.of("store"), relRoot.fields.rightList()); @@ -66,7 +66,7 @@ void testConvertRootStructField() { // the sub field names are stored within RelRoot.validatedRowType assertEquals(List.of("store"), relRoot.validatedRowType.getFieldNames()); - RelDataType storeFieldDataType = relRoot.validatedRowType.getFieldList().get(0).getType(); + final RelDataType storeFieldDataType = relRoot.validatedRowType.getFieldList().get(0).getType(); assertEquals(List.of("store_id", "store_name"), storeFieldDataType.getFieldNames()); } @@ -75,8 +75,8 @@ void testConvertRootArrayWithStructField() { final Type structType = TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING); final Type arrayType = TypeCreator.REQUIRED.list(structType); - Set types = Set.of(arrayType); - Root root = + final Set types = Set.of(arrayType); + final Root root = Root.builder() .input( substraitBuilder.namedScan( @@ -84,7 +84,7 @@ void testConvertRootArrayWithStructField() { .addNames("store", "store_id", "store_name") .build(); - RelRoot relRoot = converter.convert(root); + final RelRoot relRoot = converter.convert(root); // Apache Calcite's RelRoot.fields only contains the top level field names assertEquals(List.of("store"), relRoot.fields.rightList()); @@ -92,7 +92,7 @@ void testConvertRootArrayWithStructField() { // the hierarchical structure is stored within RelRoot.validatedRowType assertEquals(List.of("store"), relRoot.validatedRowType.getFieldNames()); - RelDataType storeFieldDataType = relRoot.validatedRowType.getFieldList().get(0).getType(); + final RelDataType storeFieldDataType = relRoot.validatedRowType.getFieldList().get(0).getType(); assertEquals(SqlTypeName.ARRAY, storeFieldDataType.getSqlTypeName()); final RelDataType arrayElementType = storeFieldDataType.getComponentType(); @@ -105,8 +105,8 @@ void testConvertRootMapWithStructValues() { final Type structType = TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING); final Type mapValueType = TypeCreator.REQUIRED.map(TypeCreator.REQUIRED.I64, structType); - Set types = Set.of(mapValueType); - Root root = + final Set types = Set.of(mapValueType); + final Root root = Root.builder() .input( substraitBuilder.namedScan( @@ -135,8 +135,8 @@ void testConvertRootMapWithStructKeys() { final Type structType = TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING); final Type mapKeyType = TypeCreator.REQUIRED.map(structType, TypeCreator.REQUIRED.I64); - Set types = Set.of(mapKeyType); - Root root = + final Set types = Set.of(mapKeyType); + final Root root = Root.builder() .input( substraitBuilder.namedScan( @@ -144,7 +144,7 @@ void testConvertRootMapWithStructKeys() { .addNames("store", "store_id", "store_name") .build(); - RelRoot relRoot = converter.convert(root); + final RelRoot relRoot = converter.convert(root); // Apache Calcite's RelRoot.fields only contains the top level field names assertEquals(List.of("store"), relRoot.fields.rightList()); @@ -152,7 +152,7 @@ void testConvertRootMapWithStructKeys() { // the hierarchical structure is stored within RelRoot.validatedRowType assertEquals(List.of("store"), relRoot.validatedRowType.getFieldNames()); - RelDataType storeFieldDataType = relRoot.validatedRowType.getFieldList().get(0).getType(); + final RelDataType storeFieldDataType = relRoot.validatedRowType.getFieldList().get(0).getType(); assertEquals(SqlTypeName.MAP, storeFieldDataType.getSqlTypeName()); final RelDataType mapKeyDataType = storeFieldDataType.getKeyType(); diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubtraitRelVisitorExtensionTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubtraitRelVisitorExtensionTest.java index a3f234c47..59e48884e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubtraitRelVisitorExtensionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubtraitRelVisitorExtensionTest.java @@ -35,7 +35,7 @@ public class Employee { public final int DEPT_ID; public final String NAME; - public Employee(int deptId, String name) { + public Employee(final int deptId, final String name) { this.DEPT_ID = deptId; this.NAME = name; } @@ -157,7 +157,7 @@ protected CustomRelBuilder( super(context, cluster, relOptSchema); } - public static CustomRelBuilder create(FrameworkConfig config) { + public static CustomRelBuilder create(final FrameworkConfig config) { return Frameworks.withPrepare( config, (cluster, relOptSchema, rootSchema, statement) -> @@ -165,8 +165,8 @@ public static CustomRelBuilder create(FrameworkConfig config) { } public CustomRelBuilder repeat(final int repeatCount) { - RelNode input = this.peek(); - RelNode repeatNode = RepeatRel.create(input, repeatCount); + final RelNode input = this.peek(); + final RelNode repeatNode = RepeatRel.create(input, repeatCount); this.push(repeatNode); return this; } diff --git a/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java index 119222e57..b70b30a05 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java @@ -24,16 +24,16 @@ static IntStream testCases() { */ @ParameterizedTest @MethodSource("testCases") - void testQuery(int query) throws IOException { - String inputSql = asString(inputSqlFile(query)); - Plan plan = assertDoesNotThrow(() -> toSubstraitPlan(inputSql), "SQL to Substrait POJO"); + void testQuery(final int query) throws IOException { + final String inputSql = asString(inputSqlFile(query)); + final Plan plan = assertDoesNotThrow(() -> toSubstraitPlan(inputSql), "SQL to Substrait POJO"); assertDoesNotThrow(() -> toSql(plan), "Substrait POJO to SQL"); - io.substrait.proto.Plan proto = + final io.substrait.proto.Plan proto = assertDoesNotThrow(() -> toProto(plan), "Substrait POJO to Substrait PROTO"); assertDoesNotThrow(() -> toSql(proto), "Substrait PROTO to SQL"); } - private String inputSqlFile(int query) { + private String inputSqlFile(final int query) { if (alternateForms.contains(query)) { return String.format("tpcds/queries/%02da.sql", query); } @@ -41,7 +41,7 @@ private String inputSqlFile(int query) { return String.format("tpcds/queries/%02d.sql", query); } - private Plan toSubstraitPlan(String sql) throws SqlParseException { + private Plan toSubstraitPlan(final String sql) throws SqlParseException { return toSubstraitPlan(sql, TPCDS_CATALOG); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java index 4b2423201..ec50e4601 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java @@ -21,20 +21,20 @@ static IntStream testCases() { */ @ParameterizedTest @MethodSource("testCases") - void testQuery(int query) throws IOException { - String inputSql = asString(String.format("tpch/queries/%02d.sql", query)); + void testQuery(final int query) throws IOException { + final String inputSql = asString(String.format("tpch/queries/%02d.sql", query)); - Plan plan = assertDoesNotThrow(() -> toSubstraitPlan(inputSql), "SQL to Substrait POJO"); + final Plan plan = assertDoesNotThrow(() -> toSubstraitPlan(inputSql), "SQL to Substrait POJO"); assertDoesNotThrow(() -> toSql(plan), "Substrait POJO to SQL"); - io.substrait.proto.Plan proto = + final io.substrait.proto.Plan proto = assertDoesNotThrow(() -> toProto(plan), "Substrait POJO to Substrait PROTO"); assertDoesNotThrow(() -> toSql(proto), "Substrait PROTO to SQL"); } - private Plan toSubstraitPlan(String sql) throws SqlParseException { + private Plan toSubstraitPlan(final String sql) throws SqlParseException { return toSubstraitPlan(sql, TPCH_CATALOG); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java index acf6942da..13936831d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/WindowFunctionTest.java @@ -36,8 +36,8 @@ void lead() throws IOException, SqlParseException { @ParameterizedTest @ValueSource(strings = {"rank", "dense_rank", "percent_rank"}) - void rankFunctions(String rankFunction) throws IOException, SqlParseException { - String query = + void rankFunctions(final String rankFunction) throws IOException, SqlParseException { + final String query = String.format( "select O_ORDERKEY, %s() over (order by O_SHIPPRIORITY) from ORDERS", rankFunction); assertFullRoundTrip(query); @@ -45,8 +45,9 @@ void rankFunctions(String rankFunction) throws IOException, SqlParseException { @ParameterizedTest @ValueSource(strings = {"rank", "dense_rank", "percent_rank"}) - void rankFunctionsWithPartitions(String rankFunction) throws IOException, SqlParseException { - String query = + void rankFunctionsWithPartitions(final String rankFunction) + throws IOException, SqlParseException { + final String query = String.format( "select O_ORDERKEY, %s() over (partition by O_CUSTKEY order by O_SHIPPRIORITY) from ORDERS", rankFunction); @@ -88,7 +89,8 @@ void unboundedPreceding() throws IOException, SqlParseException { LogicalProject(EXPR$0=[MIN($7) OVER (PARTITION BY $1 ORDER BY $4 ROWS UNBOUNDED PRECEDING)]) LogicalTableScan(table=[[ORDERS]]) */ - String overClause = "partition by O_CUSTKEY order by O_ORDERDATE rows unbounded preceding"; + final String overClause = + "partition by O_CUSTKEY order by O_ORDERDATE rows unbounded preceding"; assertFullRoundTrip( String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); } @@ -99,7 +101,7 @@ void unboundedFollowing() throws IOException, SqlParseException { LogicalProject(EXPR$0=[MAX($7) OVER (PARTITION BY $1 ORDER BY $4 ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)]) LogicalTableScan(table=[[ORDERS]]) */ - String overClaus = + final String overClaus = "partition by O_CUSTKEY order by O_ORDERDATE rows between current row AND unbounded following"; assertFullRoundTrip( String.format("select max(O_SHIPPRIORITY) over (%s) from ORDERS", overClaus)); @@ -111,7 +113,7 @@ void rowsPrecedingToCurrent() throws IOException, SqlParseException { LogicalProject(EXPR$0=[MIN($7) OVER (PARTITION BY $1 ORDER BY $4 ROWS 1 PRECEDING)]) LogicalTableScan(table=[[ORDERS]]) */ - String overClause = + final String overClause = "partition by O_CUSTKEY order by O_ORDERDATE rows between 1 preceding and current row"; assertFullRoundTrip( String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); @@ -123,7 +125,7 @@ void currentToRowsFollowing() throws IOException, SqlParseException { LogicalProject(EXPR$0=[MAX($7) OVER (PARTITION BY $1 ORDER BY $4 ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING)]) LogicalTableScan(table=[[ORDERS]]) */ - String overClause = + final String overClause = "partition by O_CUSTKEY order by O_ORDERDATE rows between current row and 2 following"; assertFullRoundTrip( String.format("select max(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); @@ -135,7 +137,7 @@ void rowsPrecedingAndFollowing() throws IOException, SqlParseException { LogicalProject(EXPR$0=[MIN($7) OVER (PARTITION BY $1 ORDER BY $4 ROWS BETWEEN 3 PRECEDING AND 4 FOLLOWING)]) LogicalTableScan(table=[[ORDERS]]) */ - String overClause = + final String overClause = "partition by O_CUSTKEY order by O_ORDERDATE rows between 3 preceding and 4 following"; assertFullRoundTrip( String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); @@ -147,7 +149,7 @@ void rangePrecedingToCurrent() throws IOException, SqlParseException { LogicalProject(EXPR$0=[MIN($7) OVER (PARTITION BY $1 ORDER BY $3 RANGE 10 PRECEDING)]) LogicalTableScan(table=[[ORDERS]]) */ - String overClause = + final String overClause = "partition by O_CUSTKEY order by O_TOTALPRICE range between 10 preceding and current row"; assertFullRoundTrip( String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); @@ -159,7 +161,7 @@ void rangeCurrentToFollowing() throws IOException, SqlParseException { LogicalProject(EXPR$0=[MIN($7) OVER (PARTITION BY $1 ORDER BY $3 RANGE BETWEEN CURRENT ROW AND 11 FOLLOWING)]) LogicalTableScan(table=[[ORDERS]]) */ - String overClause = + final String overClause = "partition by O_CUSTKEY order by O_TOTALPRICE range between current row and 11 following"; assertFullRoundTrip( String.format("select min(O_SHIPPRIORITY) over (%s) from ORDERS", overClause)); @@ -171,7 +173,8 @@ class AggregateFunctionInvocations { @ParameterizedTest @ValueSource(strings = {"avg", "count", "max", "min", "sum"}) - void standardAggregateFunctions(String aggFunction) throws SqlParseException, IOException { + void standardAggregateFunctions(final String aggFunction) + throws SqlParseException, IOException { assertFullRoundTrip( String.format( "select %s(L_LINENUMBER) over (partition BY L_PARTKEY) from lineitem", aggFunction)); @@ -182,14 +185,14 @@ void standardAggregateFunctions(String aggFunction) throws SqlParseException, IO void rejectQueriesWithIgnoreNulls() { // IGNORE NULLS cannot be specified in the Substrait representation. // Queries using it should be rejected. - String query = "select last_value(L_LINENUMBER) ignore nulls over () from lineitem"; + final String query = "select last_value(L_LINENUMBER) ignore nulls over () from lineitem"; assertThrows(IllegalArgumentException.class, () -> assertFullRoundTrip(query)); } @ParameterizedTest @ValueSource(strings = {"lag", "lead"}) - void lagLeadFunctions(String function) { - Rel rel = + void lagLeadFunctions(final String function) { + final Rel rel = substraitBuilder.project( input -> List.of( @@ -211,8 +214,8 @@ void lagLeadFunctions(String function) { @ParameterizedTest @ValueSource(strings = {"lag", "lead"}) - void lagLeadWithOffset(String function) { - Rel rel = + void lagLeadWithOffset(final String function) { + final Rel rel = substraitBuilder.project( input -> List.of( @@ -235,8 +238,8 @@ void lagLeadWithOffset(String function) { @ParameterizedTest @ValueSource(strings = {"lag", "lead"}) - void lagLeadWithOffsetAndDefault(String function) { - Rel rel = + void lagLeadWithOffsetAndDefault(final String function) { + final Rel rel = substraitBuilder.project( input -> List.of( diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java index d39cc2084..050fee8ef 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java @@ -17,11 +17,11 @@ class AggregateFunctionConverterTest extends PlanTestBase { @Test void testFunctionFinderMatch() { - AggregateFunctionConverter converter = + final AggregateFunctionConverter converter = new AggregateFunctionConverter( extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT); - FunctionFinder functionFinder = + final FunctionFinder functionFinder = converter.getFunctionFinder( AggregateCall.create( new SqlSumEmptyIsZeroAggFunction(), diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/SetUtils.java b/isthmus/src/test/java/io/substrait/isthmus/utils/SetUtils.java index 3e57905e2..705d8eda0 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/utils/SetUtils.java +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/SetUtils.java @@ -17,10 +17,10 @@ private SetUtils() {} * @param multi whether to use more than two relations * @return a sql query */ - public static String getSetQuery(Set.SetOp op, boolean multi) { - String opString = asString(op); + public static String getSetQuery(final Set.SetOp op, final boolean multi) { + final String opString = asString(op); - StringBuilder query = new StringBuilder(); + final StringBuilder query = new StringBuilder(); query.append( "select p_partkey as partkey, p_name as str, (p_partkey + p_partkey) as expr\n" + "from part where p_partkey > cast(100 as bigint)\n"); @@ -39,7 +39,7 @@ public static String getSetQuery(Set.SetOp op, boolean multi) { } } - private static String asString(Set.SetOp op) { + private static String asString(final Set.SetOp op) { switch (op) { case MINUS_PRIMARY: return "EXCEPT"; diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java index 2c90f133d..fe80b4607 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java @@ -16,14 +16,14 @@ public class UserTypeFactory { private final String urn; private final String name; - public UserTypeFactory(String urn, String name) { + public UserTypeFactory(final String urn, final String name) { this.urn = urn; this.name = name; this.N = new InnerType(true, name); this.R = new InnerType(false, name); } - public RelDataType createCalcite(boolean nullable) { + public RelDataType createCalcite(final boolean nullable) { if (nullable) { return N; } else { @@ -31,11 +31,11 @@ public RelDataType createCalcite(boolean nullable) { } } - public Type createSubstrait(boolean nullable) { + public Type createSubstrait(final boolean nullable) { return TypeCreator.of(nullable).userDefined(urn, name); } - public boolean isTypeFromFactory(RelDataType type) { + public boolean isTypeFromFactory(final RelDataType type) { return type == N || type == R; } @@ -43,7 +43,7 @@ private static class InnerType extends RelDataTypeImpl { private final boolean nullable; private final String name; - private InnerType(boolean nullable, String name) { + private InnerType(final boolean nullable, final String name) { computeDigest(); this.nullable = nullable; this.name = name; @@ -60,7 +60,7 @@ public SqlTypeName getSqlTypeName() { } @Override - protected void generateTypeString(StringBuilder sb, boolean withDetail) { + protected void generateTypeString(final StringBuilder sb, final boolean withDetail) { sb.append(name); } }