diff --git a/pom.xml b/pom.xml index 07e04b9ade472..045624ff224f6 100644 --- a/pom.xml +++ b/pom.xml @@ -222,6 +222,7 @@ presto-hudi presto-native-execution presto-native-tests + presto-native-tvf presto-router presto-open-telemetry redis-hbo-provider @@ -1191,6 +1192,12 @@ ${project.version} + + com.facebook.presto + presto-native-tvf + ${project.version} + + com.facebook.hive hive-dwrf diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java index 15c33950ce14b..630f4670f6cc2 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java @@ -86,14 +86,6 @@ public Field(Optional nodeLocation, Optional relati this.aliased = aliased; } - public static Field newUnqualified(Optional name, Type type) - { - requireNonNull(name, "name is null"); - requireNonNull(type, "type is null"); - - return new Field(Optional.empty(), Optional.empty(), name, type, false, Optional.empty(), Optional.empty(), false); - } - public Optional getNodeLocation() { return nodeLocation; diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/tvf/QueryFunctionProvider.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/tvf/QueryFunctionProvider.java index e3c930926895b..09c820ca6e697 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/tvf/QueryFunctionProvider.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/tvf/QueryFunctionProvider.java @@ -46,7 +46,7 @@ import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; +import static com.facebook.presto.spi.function.table.GenericTableReturnTypeSpecification.GENERIC_TABLE; import static java.util.Objects.requireNonNull; public class QueryFunctionProvider diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/RowType.java b/presto-common/src/main/java/com/facebook/presto/common/type/RowType.java index 347665bba084e..ad8be60b3f665 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/RowType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/RowType.java @@ -19,6 +19,8 @@ import com.facebook.presto.common.block.BlockBuilderStatus; import com.facebook.presto.common.block.RowBlockBuilder; import com.facebook.presto.common.function.SqlFunctionProperties; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import java.util.ArrayList; import java.util.Collections; @@ -41,13 +43,26 @@ public class RowType private final Optional typeSignature; private RowType(List fields) + { + this(fields, fields.stream() + .map(Field::getType) + .collect(toList())); + } + + private RowType(List fields, List fieldTypes) + { + this(fields, fieldTypes, containsDistinctType(fieldTypes) ? Optional.empty() : Optional.of(makeSignature(fields))); + } + + @JsonCreator + public RowType(List fields, + List fieldTypes, + @JsonProperty("typeSignature") Optional typeSignature) { super(Block.class); this.fields = fields; - this.fieldTypes = fields.stream() - .map(Field::getType) - .collect(toList()); - this.typeSignature = containsDistinctType(this.fieldTypes) ? Optional.empty() : Optional.of(makeSignature(fields)); + this.fieldTypes = fieldTypes; + this.typeSignature = typeSignature; } public static RowType from(List fields) @@ -120,6 +135,12 @@ public BlockBuilder createBlockBuilder(BlockBuilderStatus blockBuilderStatus, in return new RowBlockBuilder(getTypeParameters(), blockBuilderStatus, expectedEntries); } + @JsonProperty + public List getFields() + { + return fields; + } + @Override public String getDisplayName() { @@ -193,18 +214,13 @@ public List getTypeParameters() return fieldTypes; } - public List getFields() - { - return fields; - } - public static class Field { private final Type type; private final Optional name; private final boolean delimited; - public Field(Optional name, Type type) + public Field(@JsonProperty("name") Optional name, @JsonProperty("type") Type type) { this(name, type, false); } @@ -216,11 +232,13 @@ public Field(Optional name, Type type, boolean delimited) this.delimited = delimited; } + @JsonProperty public Type getType() { return type; } + @JsonProperty public Optional getName() { return name; diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java index a45f03d7c0618..19ac259c84f33 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java @@ -52,6 +52,8 @@ import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.procedure.BaseProcedure; import com.facebook.presto.spi.procedure.DistributedProcedure; import com.facebook.presto.spi.procedure.Procedure; @@ -86,6 +88,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; import static com.facebook.presto.metadata.FunctionExtractor.extractFunctions; import static com.facebook.presto.spi.ConnectorId.createInformationSchemaConnectorId; @@ -215,6 +218,12 @@ public synchronized void addConnectorFactory(ConnectorFactory connectorFactory) ConnectorFactory existingConnectorFactory = connectorFactories.putIfAbsent(connectorFactory.getName(), connectorFactory); checkArgument(existingConnectorFactory == null, "Connector %s is already registered", connectorFactory.getName()); handleResolver.addConnectorName(connectorFactory.getName(), connectorFactory.getHandleResolver()); + connectorFactory.getTableFunctionHandleResolver().ifPresent(resolver -> { + handleResolver.addTableFunctionNamespace(connectorFactory.getName(), resolver); + }); + connectorFactory.getTableFunctionSplitResolver().ifPresent(resolver -> { + handleResolver.addTableFunctionSplitNamespace(connectorFactory.getName(), resolver); + }); } public synchronized ConnectorId createConnection(String catalogName, String connectorName, Map properties) @@ -334,6 +343,7 @@ private synchronized void addConnectorInternal(MaterializedConnector connector) metadataManager.getAnalyzePropertyManager().addProperties(connectorId, connector.getAnalyzeProperties()); metadataManager.getSessionPropertyManager().addConnectorSessionProperties(connectorId, connector.getSessionProperties()); metadataManager.getFunctionAndTypeManager().getTableFunctionRegistry().addTableFunctions(connectorId, connector.getTableFunctions()); + metadataManager.getFunctionAndTypeManager().addTableFunctionProcessorProvider(connectorId, connector.getTableFunctionProcessorProvider()); } public synchronized void dropConnection(String catalogName) @@ -346,6 +356,7 @@ public synchronized void dropConnection(String catalogName) removeConnectorInternal(createInformationSchemaConnectorId(connectorId)); removeConnectorInternal(createSystemTablesConnectorId(connectorId)); metadataManager.getFunctionAndTypeManager().getTableFunctionRegistry().removeTableFunctions(connectorId); + metadataManager.getFunctionAndTypeManager().removeTableFunctionProcessorProvider(connectorId); }); } @@ -422,6 +433,7 @@ private static class MaterializedConnector private final Set> functions; private final Set connectorTableFunctions; + private final Function connectorTableFunctionProcessorProvider; private final ConnectorPageSourceProvider pageSourceProvider; private final Optional pageSinkProvider; private final Optional indexProvider; @@ -459,6 +471,7 @@ public MaterializedConnector(ConnectorId connectorId, Connector connector) Set connectorTableFunctions = connector.getTableFunctions(); requireNonNull(connectorTableFunctions, format("Connector '%s' returned a null table functions set", connectorId)); this.connectorTableFunctions = ImmutableSet.copyOf(connectorTableFunctions); + this.connectorTableFunctionProcessorProvider = connector.getTableFunctionProcessorProvider(); ConnectorPageSourceProvider connectorPageSourceProvider = null; try { @@ -660,5 +673,10 @@ public Set getTableFunctions() { return connectorTableFunctions; } + + public Function getTableFunctionProcessorProvider() + { + return connectorTableFunctionProcessorProvider; + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnector.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnector.java index 860f6e47a6285..e2f82f59ca499 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnector.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnector.java @@ -15,17 +15,22 @@ import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.transaction.TransactionId; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.operator.table.ExcludeColumns; +import com.facebook.presto.operator.table.Sequence; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorPageSource; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.ConnectorTableLayout; import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.facebook.presto.spi.ConnectorTableLayoutResult; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SchemaTablePrefix; import com.facebook.presto.spi.SplitContext; @@ -34,6 +39,9 @@ import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.transaction.IsolationLevel; import com.facebook.presto.transaction.InternalConnector; @@ -45,6 +53,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Function; import static java.util.Objects.requireNonNull; @@ -56,12 +65,18 @@ public class GlobalSystemConnector private final String connectorId; private final Set systemTables; private final Set procedures; + private final Set tableFunctions; + private final NodeManager nodeManager; + private final FunctionAndTypeManager functionAndTypeManager; - public GlobalSystemConnector(String connectorId, Set systemTables, Set procedures) + public GlobalSystemConnector(String connectorId, Set systemTables, Set procedures, Set tableFunctions, NodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager) { this.connectorId = requireNonNull(connectorId, "connectorId is null"); this.systemTables = ImmutableSet.copyOf(requireNonNull(systemTables, "systemTables is null")); this.procedures = ImmutableSet.copyOf(requireNonNull(procedures, "procedures is null")); + this.tableFunctions = ImmutableSet.copyOf(requireNonNull(tableFunctions, "tableFunctions is null")); + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); } @Override @@ -138,8 +153,18 @@ public Map> listTableColumns(ConnectorSess @Override public ConnectorSplitManager getSplitManager() { - return (transactionHandle, session, layout, splitSchedulingContext) -> { - throw new UnsupportedOperationException(); + return new ConnectorSplitManager() { + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) + { + throw new UnsupportedOperationException(); + } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorTableFunctionHandle function) + { + return function.getSplits(transaction, session, nodeManager, functionAndTypeManager); + } }; } @@ -166,4 +191,24 @@ public Set getProcedures() { return procedures; } + + @Override + public Set getTableFunctions() + { + return tableFunctions; + } + + @Override + public Function getTableFunctionProcessorProvider() + { + return connectorTableFunctionHandle -> { + if (connectorTableFunctionHandle instanceof ExcludeColumns.ExcludeColumnsFunctionHandle) { + return ExcludeColumns.getExcludeColumnsFunctionProcessorProvider(); + } + else if (connectorTableFunctionHandle instanceof Sequence.SequenceFunctionHandle) { + return Sequence.getSequenceFunctionProcessorProvider(); + } + return null; + }; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnectorFactory.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnectorFactory.java index 3d1c8e329188b..8cb92a83e9003 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnectorFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnectorFactory.java @@ -13,11 +13,14 @@ */ package com.facebook.presto.connector.system; +import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.connector.Connector; import com.facebook.presto.spi.connector.ConnectorContext; import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.procedure.Procedure; import com.google.common.collect.ImmutableSet; import jakarta.inject.Inject; @@ -32,12 +35,18 @@ public class GlobalSystemConnectorFactory { private final Set tables; private final Set procedures; + private final Set tableFunctions; + private final NodeManager nodeManager; + private final FunctionAndTypeManager functionAndTypeManager; @Inject - public GlobalSystemConnectorFactory(Set tables, Set procedures) + public GlobalSystemConnectorFactory(Set tables, Set procedures, Set tableFunctions, NodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager) { this.tables = ImmutableSet.copyOf(requireNonNull(tables, "tables is null")); this.procedures = ImmutableSet.copyOf(requireNonNull(procedures, "procedures is null")); + this.tableFunctions = ImmutableSet.copyOf(requireNonNull(tableFunctions, "tableFunctions is null")); + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); } @Override @@ -55,6 +64,6 @@ public ConnectorHandleResolver getHandleResolver() @Override public Connector create(String catalogName, Map config, ConnectorContext context) { - return new GlobalSystemConnector(catalogName, tables, procedures); + return new GlobalSystemConnector(catalogName, tables, procedures, tableFunctions, nodeManager, functionAndTypeManager); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorModule.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorModule.java index 19728ef156e78..ded582255460e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorModule.java @@ -27,7 +27,13 @@ import com.facebook.presto.connector.system.jdbc.TableTypeJdbcTable; import com.facebook.presto.connector.system.jdbc.TypesJdbcTable; import com.facebook.presto.connector.system.jdbc.UdtJdbcTable; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.nodeManager.PluginNodeManager; +import com.facebook.presto.operator.table.ExcludeColumns; +import com.facebook.presto.operator.table.Sequence; +import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.SystemTable; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.procedure.Procedure; import com.google.common.collect.ImmutableMap; import com.google.inject.Binder; @@ -77,6 +83,13 @@ public void configure(Binder binder) binder.bind(GlobalSystemConnectorFactory.class).in(Scopes.SINGLETON); binder.bind(SystemConnectorRegistrar.class).asEagerSingleton(); + binder.bind(PluginNodeManager.class).in(Scopes.SINGLETON); + binder.bind(NodeManager.class).to(PluginNodeManager.class).in(Scopes.SINGLETON); + binder.bind(FunctionAndTypeManager.class).in(Scopes.SINGLETON); + + Multibinder tableFunctions = Multibinder.newSetBinder(binder, ConnectorTableFunction.class); + tableFunctions.addBinding().toProvider(ExcludeColumns.class).in(Scopes.SINGLETON); + tableFunctions.addBinding().toProvider(Sequence.class).in(Scopes.SINGLETON); } @ProvidesIntoSet diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java index 426011d14447e..f7e690c015860 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java @@ -649,7 +649,7 @@ private PlanRoot runCreateLogicalPlanAsync() private void createQueryScheduler(PlanRoot plan) { - CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager::getSplits); + CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager); // ensure split sources are closed stateMachine.addStateChangeListener(state -> { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java index 899552d6feedf..0268c52825ce1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java @@ -1081,7 +1081,10 @@ public ListenableFuture processFor(Duration duration) @Override public String getInfo() { - return (partitionedSplit == null) ? "" : partitionedSplit.getSplit().getInfo().toString(); + if (partitionedSplit != null && partitionedSplit.getSplit() != null && partitionedSplit.getSplit().getInfo() != null) { + return partitionedSplit.getSplit().getInfo().toString(); + } + return ""; } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java index 2781b84a63240..a8fa369a1a3d3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java @@ -34,6 +34,7 @@ import com.facebook.presto.common.type.TypeWithName; import com.facebook.presto.common.type.UserDefinedType; import com.facebook.presto.operator.window.WindowFunctionSupplier; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.StandardErrorCode; @@ -56,6 +57,12 @@ import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlFunctionSupplier; import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionMetadata; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.tvf.TVFProvider; +import com.facebook.presto.spi.tvf.TVFProviderContext; +import com.facebook.presto.spi.tvf.TVFProviderFactory; import com.facebook.presto.spi.type.TypeManagerContext; import com.facebook.presto.spi.type.TypeManagerFactory; import com.facebook.presto.sql.analyzer.FeaturesConfig; @@ -92,6 +99,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; import java.util.regex.Pattern; import static com.facebook.presto.SystemSessionProperties.isExperimentalFunctionsEnabled; @@ -105,6 +113,7 @@ import static com.facebook.presto.metadata.FunctionSignatureMatcher.decideAndThrow; import static com.facebook.presto.metadata.SessionFunctionHandle.SESSION_NAMESPACE; import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables; +import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; import static com.facebook.presto.spi.StandardErrorCode.AMBIGUOUS_FUNCTION_CALL; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_NOT_FOUND; @@ -160,9 +169,13 @@ public class FunctionAndTypeManager private final AtomicReference>> servingTypeManagerParametricTypesSupplier; private final BuiltInWorkerFunctionNamespaceManager builtInWorkerFunctionNamespaceManager; private final BuiltInPluginFunctionNamespaceManager builtInPluginFunctionNamespaceManager; + private final ConcurrentHashMap> tableFunctionProcessorProviderMap = new ConcurrentHashMap<>(); private final FunctionsConfig functionsConfig; private final Set types; + private final Map tvfProviderFactories = new ConcurrentHashMap<>(); + private final Map tvfProviders = new ConcurrentHashMap<>(); + @Inject public FunctionAndTypeManager( TransactionManager transactionManager, @@ -375,6 +388,39 @@ public void addFunctionNamespace(String catalogName, FunctionNamespaceManager fu } } + public void loadTVFProvider(String tvfProviderName, NodeManager nodeManager) + { + requireNonNull(tvfProviderName, "tvfProviderName is null"); + TVFProviderFactory factory = tvfProviderFactories.get(tvfProviderName); + checkState(factory != null, "No factory for tvf provider %s", tvfProviderName); + TVFProvider tvfProvider = factory.createTVFProvider(ImmutableMap.of(), new TVFProviderContext(nodeManager, this)); + + if (tvfProviders.putIfAbsent(new ConnectorId(tvfProviderName), tvfProvider) != null) { + throw new IllegalArgumentException(format("TVF provider [%s] is already registered", tvfProvider)); + } + } + + public void loadTVFProviders(NodeManager nodeManager) + { + for (String tvfProviderName : tvfProviderFactories.keySet()) { + loadTVFProvider(tvfProviderName, nodeManager); + } + } + + public void addTVFProviderFactory(TVFProviderFactory factory) + { + if (tvfProviderFactories.putIfAbsent(factory.getName(), factory) != null) { + throw new IllegalArgumentException(format("TVF provider '%s' is already registered", factory.getName())); + } + handleResolver.addTableFunctionNamespace(factory.getName(), factory.getTableFunctionHandleResolver()); + handleResolver.addTableFunctionSplitNamespace(factory.getName(), factory.getTableFunctionSplitResolver()); + } + + public HandleResolver getHandleResolver() + { + return handleResolver; + } + @Override public FunctionMetadata getFunctionMetadata(FunctionHandle functionHandle) { @@ -485,6 +531,23 @@ public void addTypeManagerFactory(TypeManagerFactory factory) } } + public TableFunctionMetadata resolveTableFunction(Session session, QualifiedName qualifiedName) + { + // Before resolving the table function, add all the TVF provider's table functions to the function registry. + if (!tableFunctionRegistry.areTvfProviderFunctionsLoaded()) { + for (ConnectorId connectorId : tvfProviders.keySet()) { + // In terms of the NativeTVFProvider, you want it to act similarly to the system connector table functions, hence we replace the Java loaded system connector table functions. + // This is only enforced when native execution is enabled and presto-native-tvf module is loaded. + if (connectorId.getCatalogName().equals("system")) { + tableFunctionRegistry.removeTableFunctions(connectorId); + } + tableFunctionRegistry.addTableFunctions(connectorId, tvfProviders.get(connectorId).getTableFunctions()); + } + tableFunctionRegistry.updateTvfProviderFunctionsLoaded(); + } + return tableFunctionRegistry.resolve(session, qualifiedName); + } + public TransactionManager getTransactionManager() { return transactionManager; @@ -704,6 +767,24 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(FunctionHand return functionNamespaceManager.get().getScalarFunctionImplementation(functionHandle); } + public TableFunctionProcessorProvider getTableFunctionProcessorProvider(TableFunctionHandle tableFunctionHandle) + { + return tableFunctionProcessorProviderMap.get(tableFunctionHandle.getConnectorId()).apply(tableFunctionHandle.getFunctionHandle()); + } + + public void addTableFunctionProcessorProvider(ConnectorId connectorId, Function tableFunctionProcessorProvider) + { + if (tableFunctionProcessorProviderMap.putIfAbsent(connectorId, tableFunctionProcessorProvider) != null) { + throw new PrestoException(ALREADY_EXISTS, + format("TableFuncitonProcessorProvider already exists for connectorId %s. Overwriting is not supported.", connectorId.getCatalogName())); + } + } + + public void removeTableFunctionProcessorProvider(ConnectorId connectorId) + { + tableFunctionProcessorProviderMap.remove(connectorId); + } + public AggregationFunctionImplementation getAggregateFunctionImplementation(FunctionHandle functionHandle) { if (isBuiltInPluginFunctionHandle(functionHandle)) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java index c46f717b8524c..3d4097afa2094 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java @@ -39,6 +39,7 @@ public void configure(Binder binder) jsonBinder(binder).addModuleBinding().to(TransactionHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(PartitioningHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(FunctionHandleJacksonModule.class); + jsonBinder(binder).addModuleBinding().to(TableFunctionJacksonHandleModule.class); binder.bind(HandleResolver.class).in(Scopes.SINGLETON); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java index 1541a98ee6bf7..7f717568ba7b1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java @@ -15,6 +15,8 @@ import com.facebook.presto.connector.informationSchema.InformationSchemaHandleResolver; import com.facebook.presto.connector.system.SystemHandleResolver; +import com.facebook.presto.operator.table.ExcludeColumns; +import com.facebook.presto.operator.table.Sequence; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorDeleteTableHandle; import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; @@ -30,12 +32,18 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.FunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import com.facebook.presto.split.EmptySplitHandleResolver; +import com.google.common.collect.ImmutableSet; import jakarta.inject.Inject; +import java.util.Map; import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.function.Function; @@ -50,6 +58,8 @@ public class HandleResolver { private final ConcurrentMap handleResolvers = new ConcurrentHashMap<>(); private final ConcurrentMap functionHandleResolvers = new ConcurrentHashMap<>(); + private final ConcurrentMap> tableFunctionHandleResolvers = new ConcurrentHashMap<>(); + private final ConcurrentMap> tableFunctionSplitResolvers = new ConcurrentHashMap<>(); @Inject public HandleResolver() @@ -61,6 +71,17 @@ public HandleResolver() functionHandleResolvers.put("$static", new MaterializedFunctionHandleResolver(new BuiltInFunctionNamespaceHandleResolver())); functionHandleResolvers.put("$session", new MaterializedFunctionHandleResolver(new SessionFunctionHandleResolver())); + + tableFunctionHandleResolvers.put( + "$system", + new MaterializedResolver<>(() -> ImmutableSet.of( + ExcludeColumns.ExcludeColumnsFunctionHandle.class, + Sequence.SequenceFunctionHandle.class))); + + tableFunctionSplitResolvers.put( + "$system", + new MaterializedResolver<>(() -> + ImmutableSet.of(Sequence.SequenceFunctionSplit.class))); } public void addConnectorName(String name, ConnectorHandleResolver resolver) @@ -72,6 +93,32 @@ public void addConnectorName(String name, ConnectorHandleResolver resolver) "Connector '%s' is already assigned to resolver: %s", name, existingResolver); } + public void addTableFunctionNamespace(String name, TableFunctionHandleResolver resolver) + { + addNamespace(name, resolver::getTableFunctionHandleClasses, tableFunctionHandleResolvers); + } + + public void addTableFunctionSplitNamespace(String name, TableFunctionSplitResolver resolver) + { + addNamespace(name, resolver::getTableFunctionSplitClasses, tableFunctionSplitResolvers); + } + + private void addNamespace( + String name, + Supplier>> classSupplier, + ConcurrentMap> resolverMap) + { + requireNonNull(name, "name is null"); + requireNonNull(classSupplier, "classSupplier is null"); + + MaterializedResolver newResolver = new MaterializedResolver<>(classSupplier); + MaterializedResolver existingResolver = resolverMap.putIfAbsent(name, newResolver); + + checkState( + existingResolver == null || existingResolver.equals(newResolver), + "Name %s is already assigned to table function resolver: %s", name, existingResolver); + } + public void addFunctionNamespace(String name, FunctionHandleResolver resolver) { requireNonNull(name, "name is null"); @@ -98,7 +145,13 @@ public String getId(ColumnHandle columnHandle) public String getId(ConnectorSplit split) { - return getId(split, MaterializedHandleResolver::getSplitClass); + try { + return getId(split, MaterializedHandleResolver::getSplitClass); + } + catch (Exception e) { + // Fallback if needed + return getFunctionId(split, tableFunctionSplitResolvers); + } } public String getId(ConnectorIndexHandle indexHandle) @@ -146,6 +199,11 @@ public String getId(ConnectorMergeTableHandle mergeHandle) return getId(mergeHandle, MaterializedHandleResolver::getMergeTableHandleClass); } + public String getId(ConnectorTableFunctionHandle tableFunctionHandle) + { + return getFunctionId(tableFunctionHandle, tableFunctionHandleResolvers); + } + public Class getTableHandleClass(String id) { return resolverFor(id).getTableHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); @@ -163,7 +221,17 @@ public Class getColumnHandleClass(String id) public Class getSplitClass(String id) { - return resolverFor(id).getSplitClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); + for (Entry> entry : tableFunctionSplitResolvers.entrySet()) { + MaterializedResolver resolver = entry.getValue(); + Optional> tableFunctionSplit = resolver.getClasses().stream() + .filter(handle -> (entry.getKey() + ":" + handle.getName()).equals(id)) + .findFirst(); + if (tableFunctionSplit.isPresent()) { + return tableFunctionSplit.get(); + } + } + return resolverFor(id).getSplitClass() + .orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); } public Class getIndexHandleClass(String id) @@ -211,6 +279,20 @@ public Class getFunctionHandleClass(String id) return resolverForFunctionNamespace(id).getFunctionHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); } + public Class getTableFunctionHandleClass(String id) + { + for (Entry> entry : tableFunctionHandleResolvers.entrySet()) { + MaterializedResolver resolver = entry.getValue(); + Optional> tableFunctionHandle = resolver.getClasses().stream() + .filter(handle -> (entry.getKey() + ":" + handle.getName()).equals(id)) + .findFirst(); + if (tableFunctionHandle.isPresent()) { + return tableFunctionHandle.get(); + } + } + throw new IllegalArgumentException("No handle resolver for table function namespace: " + id); + } + private MaterializedHandleResolver resolverFor(String id) { MaterializedHandleResolver resolver = handleResolvers.get(id); @@ -253,6 +335,26 @@ private String getFunctionNamespaceId(T handle, Function String getFunctionId( + T handle, + Map> resolvers) + { + for (Entry> entry : resolvers.entrySet()) { + try { + Optional id = entry.getValue().getClasses().stream() + .filter(clazz -> clazz.isInstance(handle)) + .map(Class::getName) + .findFirst(); + if (id.isPresent()) { + return entry.getKey() + ":" + id.get(); + } + } + catch (UnsupportedOperationException ignored) { + } + } + throw new IllegalArgumentException("No function namespace for instance: " + handle); + } + private static class MaterializedHandleResolver { private final Optional> tableHandle; @@ -267,6 +369,7 @@ private static class MaterializedHandleResolver private final Optional> distributedProcedureHandle; private final Optional> partitioningHandle; private final Optional> transactionHandle; + private final Optional> tableFunctionHandle; public MaterializedHandleResolver(ConnectorHandleResolver resolver) { @@ -282,6 +385,7 @@ public MaterializedHandleResolver(ConnectorHandleResolver resolver) partitioningHandle = getHandleClass(resolver::getPartitioningHandleClass); transactionHandle = getHandleClass(resolver::getTransactionHandleClass); distributedProcedureHandle = getHandleClass(resolver::getDistributedProcedureHandleClass); + tableFunctionHandle = getHandleClass(resolver::getTableFunctionHandleClass); } private static Optional> getHandleClass(Supplier> callable) @@ -354,6 +458,11 @@ public Optional> getTransactionHandl return transactionHandle; } + public Optional> getTableFunctionHandleClass() + { + return tableFunctionHandle; + } + @Override public boolean equals(Object o) { @@ -374,13 +483,14 @@ public boolean equals(Object o) Objects.equals(deleteTableHandle, that.deleteTableHandle) && Objects.equals(mergeTableHandle, that.mergeTableHandle) && Objects.equals(partitioningHandle, that.partitioningHandle) && - Objects.equals(transactionHandle, that.transactionHandle); + Objects.equals(transactionHandle, that.transactionHandle) && + Objects.equals(tableFunctionHandle, that.tableFunctionHandle); } @Override public int hashCode() { - return Objects.hash(tableHandle, layoutHandle, columnHandle, split, indexHandle, outputTableHandle, insertTableHandle, deleteTableHandle, mergeTableHandle, partitioningHandle, transactionHandle); + return Objects.hash(tableHandle, layoutHandle, columnHandle, split, indexHandle, outputTableHandle, insertTableHandle, deleteTableHandle, mergeTableHandle, partitioningHandle, transactionHandle, tableFunctionHandle); } } @@ -427,4 +537,48 @@ public int hashCode() return Objects.hash(functionHandle); } } + + private static class MaterializedResolver + { + private final Set> classes; + + public MaterializedResolver(Supplier>> classSupplier) + { + this.classes = getSafe(classSupplier); + } + + private static Set> getSafe(Supplier>> classSupplier) + { + try { + return classSupplier.get(); + } + catch (UnsupportedOperationException e) { + return ImmutableSet.of(); + } + } + + public Set> getClasses() + { + return classes; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + MaterializedResolver that = (MaterializedResolver) o; + return Objects.equals(classes, that.classes); + } + + @Override + public int hashCode() + { + return Objects.hash(classes); + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java index 9dbf15ec0adc4..b5bfedf3fde51 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java @@ -605,5 +605,10 @@ default boolean isPushdownSupportedForFilter(Session session, TableHandle tableH String normalizeIdentifier(Session session, String catalogName, String identifier); + /** + * Attempt to push down the table function invocation into the connector. + * @return {@link Optional#empty()} if the connector doesn't support table function invocation pushdown, + * or an {@code Optional>} containing the table handle that will be used in place of the table function invocation. + */ Optional> applyTableFunction(Session session, TableFunctionHandle handle); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionJacksonHandleModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionJacksonHandleModule.java new file mode 100644 index 0000000000000..dcc8a7aaf8d67 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionJacksonHandleModule.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Inject; +import jakarta.inject.Provider; + +import java.util.Optional; +import java.util.function.Function; + +public class TableFunctionJacksonHandleModule + extends AbstractTypedJacksonModule +{ + @Inject + public TableFunctionJacksonHandleModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) + { + super(ConnectorTableFunctionHandle.class, + handleResolver::getId, + handleResolver::getTableFunctionHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorTableFunctionHandleCodec)); + } + + public TableFunctionJacksonHandleModule( + HandleResolver handleResolver, + FeaturesConfig featuresConfig, + Function>> codecExtractor) + { + super(ConnectorTableFunctionHandle.class, + handleResolver::getId, + handleResolver::getTableFunctionHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + codecExtractor); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java index da7e23f9bbf2b..7b0190936e0c9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java @@ -21,8 +21,9 @@ import com.facebook.presto.spi.function.SchemaFunctionName; import com.facebook.presto.spi.function.table.ArgumentSpecification; import com.facebook.presto.spi.function.table.ConnectorTableFunction; -import com.facebook.presto.spi.function.table.ReturnTypeSpecification.DescribedTable; +import com.facebook.presto.spi.function.table.DescribedTableReturnTypeSpecification; import com.facebook.presto.spi.function.table.TableArgumentSpecification; +import com.facebook.presto.spi.function.table.TableFunctionMetadata; import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; @@ -33,15 +34,17 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import static com.facebook.presto.spi.StandardErrorCode.GENERIC_USER_ERROR; import static com.facebook.presto.spi.StandardErrorCode.SESSION_CATALOG_NOT_SET; import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CATALOG_NOT_SPECIFIED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.SCHEMA_NOT_SPECIFIED; import static com.google.common.base.Preconditions.checkState; +import static java.lang.String.format; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; @@ -50,6 +53,7 @@ public class TableFunctionRegistry { // catalog name in the original case; schema and function name in lowercase private final Map> tableFunctions = new ConcurrentHashMap<>(); + private final AtomicBoolean tvfProviderFunctionsLoaded = new AtomicBoolean(false); public void addTableFunctions(ConnectorId catalogName, Collection functions) { @@ -76,6 +80,16 @@ public void removeTableFunctions(ConnectorId catalogName) tableFunctions.remove(catalogName); } + public boolean areTvfProviderFunctionsLoaded() + { + return tvfProviderFunctionsLoaded.get(); + } + + public void updateTvfProviderFunctionsLoaded() + { + tvfProviderFunctionsLoaded.compareAndSet(false, true); + } + public static List toPath(Session session, QualifiedName name) { List parts = name.getParts(); @@ -101,6 +115,9 @@ public static List toPath(Session session, QualifiedN // add resolved path items names.add(new CatalogSchemaFunctionName(currentCatalog, currentSchema, parts.get(0))); + + // add builtin path items + names.add(new CatalogSchemaFunctionName("system", "builtin", parts.get(0))); return names.build(); } @@ -108,7 +125,7 @@ public static List toPath(Session session, QualifiedN * Resolve table function with given qualified name. * Table functions are resolved case-insensitive for consistency with existing scalar function resolution. */ - public Optional resolve(Session session, QualifiedName qualifiedName) + public TableFunctionMetadata resolve(Session session, QualifiedName qualifiedName) { for (CatalogSchemaFunctionName name : toPath(session, qualifiedName)) { ConnectorId connectorId = new ConnectorId(name.getCatalogName()); @@ -118,12 +135,27 @@ public Optional resolve(Session session, QualifiedName qu String lowercasedFunctionName = name.getSchemaFunctionName().getFunctionName().toLowerCase(ENGLISH); TableFunctionMetadata function = catalogFunctions.get(new SchemaFunctionName(lowercasedSchemaName, lowercasedFunctionName)); if (function != null) { - return Optional.of(function); + return function; } } } - return Optional.empty(); + return null; + } + + public TableFunctionMetadata resolve(ConnectorId connectorId, CatalogSchemaFunctionName name) + { + Map catalogFunctions = tableFunctions.get(connectorId); + if (catalogFunctions != null) { + String lowercasedSchemaName = name.getSchemaFunctionName().getSchemaName().toLowerCase(ENGLISH); + String lowercasedFunctionName = name.getSchemaFunctionName().getFunctionName().toLowerCase(ENGLISH); + TableFunctionMetadata function = catalogFunctions.get(new SchemaFunctionName(lowercasedSchemaName, lowercasedFunctionName)); + if (function != null) { + return function; + } + } + + throw new PrestoException(GENERIC_USER_ERROR, format("Table functions for catalog %s could not be resolved.", connectorId.getCatalogName())); } private static void validateTableFunction(ConnectorTableFunction tableFunction) @@ -154,8 +186,8 @@ private static void validateTableFunction(ConnectorTableFunction tableFunction) // Such a table argument is implicitly 'prune when empty'. The TableArgumentSpecification.Builder enforces the 'prune when empty' property // for a table argument with row semantics. - if (tableFunction.getReturnTypeSpecification() instanceof DescribedTable) { - DescribedTable describedTable = (DescribedTable) tableFunction.getReturnTypeSpecification(); + if (tableFunction.getReturnTypeSpecification() instanceof DescribedTableReturnTypeSpecification) { + DescribedTableReturnTypeSpecification describedTable = (DescribedTableReturnTypeSpecification) tableFunction.getReturnTypeSpecification(); checkArgument(describedTable.getDescriptor().isTyped(), "field types missing in returned type specification"); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/EmptyTableFunctionPartition.java b/presto-main-base/src/main/java/com/facebook/presto/operator/EmptyTableFunctionPartition.java new file mode 100644 index 0000000000000..bda83ae6319d4 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/EmptyTableFunctionPartition.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.RunLengthEncodedBlock; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; + +import java.util.List; + +import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; +import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +/** + * This is a class representing empty input to a table function. An EmptyTableFunctionPartition is created + * when the table function has KEEP WHEN EMPTY property, which means that the function should be executed + * even if the input is empty, and all the table arguments are empty relations. + *

+ * An EmptyTableFunctionPartition is created and processed once per node. To avoid duplicated execution, + * a table function having KEEP WHEN EMPTY property must have single distribution. + */ +public class EmptyTableFunctionPartition + implements TableFunctionPartition +{ + private final TableFunctionDataProcessor tableFunction; + private final int properChannelsCount; + private final int passThroughSourcesCount; + private final Type[] passThroughTypes; + + public EmptyTableFunctionPartition(TableFunctionDataProcessor tableFunction, int properChannelsCount, int passThroughSourcesCount, List passThroughTypes) + { + this.tableFunction = requireNonNull(tableFunction, "tableFunction is null"); + this.properChannelsCount = properChannelsCount; + this.passThroughSourcesCount = passThroughSourcesCount; + this.passThroughTypes = passThroughTypes.toArray(new Type[] {}); + } + + @Override + public WorkProcessor toOutputPages() + { + return WorkProcessor.create(() -> { + TableFunctionProcessorState state = tableFunction.process(null); + if (state == FINISHED) { + return WorkProcessor.ProcessState.finished(); + } + if (state instanceof TableFunctionProcessorState.Blocked) { + return WorkProcessor.ProcessState.blocked(toListenableFuture(((TableFunctionProcessorState.Blocked) state).getFuture())); + } + TableFunctionProcessorState.Processed processed = (TableFunctionProcessorState.Processed) state; + if (processed.getResult() != null) { + return WorkProcessor.ProcessState.ofResult(appendNullsForPassThroughColumns(processed.getResult())); + } + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, "When function got no input, it should either produce output or return Blocked state"); + }); + } + + private Page appendNullsForPassThroughColumns(Page page) + { + if (page.getChannelCount() != properChannelsCount + passThroughSourcesCount) { + throw new PrestoException( + FUNCTION_IMPLEMENTATION_ERROR, + format( + "Table function returned a page containing %s channels. Expected channel number: %s (%s proper columns, %s pass-through index columns)", + page.getChannelCount(), + properChannelsCount + passThroughSourcesCount, + properChannelsCount, + passThroughSourcesCount)); + } + + Block[] resultBlocks = new Block[properChannelsCount + passThroughTypes.length]; + + // proper outputs first + for (int channel = 0; channel < properChannelsCount; channel++) { + resultBlocks[channel] = page.getBlock(channel); + } + + // pass-through columns next + // because no input was processed, all pass-through indexes in the result page must be null (there are no input rows they could refer to). + // for performance reasons this is not checked. All pass-through columns are filled with nulls. + int channel = properChannelsCount; + for (Type type : passThroughTypes) { + resultBlocks[channel] = RunLengthEncodedBlock.create(type, null, page.getPositionCount()); + channel++; + } + + // pass the position count so that the Page can be successfully created in the case when there are no output channels (resultBlocks is empty) + return new Page(page.getPositionCount(), resultBlocks); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/LeafTableFunctionOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/LeafTableFunctionOperator.java new file mode 100644 index 0000000000000..3eb272cef09ec --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/LeafTableFunctionOperator.java @@ -0,0 +1,205 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.execution.ScheduledSplit; +import com.facebook.presto.metadata.Split; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.UpdatablePageSource; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState.Blocked; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed; +import com.facebook.presto.spi.function.table.TableFunctionSplitProcessor; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.google.common.util.concurrent.ListenableFuture; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; + +import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class LeafTableFunctionOperator + implements SourceOperator +{ + public static class LeafTableFunctionOperatorFactory + implements SourceOperatorFactory + { + private final int operatorId; + private final PlanNodeId sourceId; + private final TableFunctionProcessorProvider tableFunctionProvider; + private final ConnectorTableFunctionHandle functionHandle; + private boolean closed; + + public LeafTableFunctionOperatorFactory(int operatorId, PlanNodeId sourceId, TableFunctionProcessorProvider tableFunctionProvider, ConnectorTableFunctionHandle functionHandle) + { + this.operatorId = operatorId; + this.sourceId = requireNonNull(sourceId, "sourceId is null"); + this.tableFunctionProvider = requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); + } + + @Override + public PlanNodeId getSourceId() + { + return sourceId; + } + + @Override + public SourceOperator createOperator(DriverContext driverContext) + { + checkState(!closed, "Factory is already closed"); + OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, sourceId, LeafTableFunctionOperator.class.getSimpleName()); + return new LeafTableFunctionOperator(operatorContext, sourceId, tableFunctionProvider, functionHandle); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + } + + private final OperatorContext operatorContext; + private final PlanNodeId sourceId; + private final TableFunctionProcessorProvider tableFunctionProvider; + private final ConnectorTableFunctionHandle functionHandle; + + private ConnectorSplit currentSplit; + private final List pendingSplits = new ArrayList<>(); + private boolean noMoreSplits; + + private TableFunctionSplitProcessor processor; + private boolean processorUsedData; + private boolean processorFinishedSplit = true; + private ListenableFuture processorBlocked = NOT_BLOCKED; + + public LeafTableFunctionOperator(OperatorContext operatorContext, PlanNodeId sourceId, TableFunctionProcessorProvider tableFunctionProvider, ConnectorTableFunctionHandle functionHandle) + { + this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); + this.sourceId = requireNonNull(sourceId, "sourceId is null"); + this.tableFunctionProvider = requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); + } + + private void resetProcessor() + { + this.processor = tableFunctionProvider.getSplitProcessor(functionHandle); + this.processorUsedData = false; + this.processorFinishedSplit = false; + this.processorBlocked = NOT_BLOCKED; + } + + @Override + public OperatorContext getOperatorContext() + { + return operatorContext; + } + + @Override + public PlanNodeId getSourceId() + { + return sourceId; + } + + @Override + public boolean needsInput() + { + return false; + } + + @Override + public void addInput(Page page) + { + throw new UnsupportedOperationException(getClass().getName() + " does not take input"); + } + + @Override + public Supplier> addSplit(ScheduledSplit split) + { + Split curSplit = requireNonNull(split, "split is null").getSplit(); + checkState(!noMoreSplits, "no more splits expected"); + ConnectorSplit curConnectorSplit = curSplit.getConnectorSplit(); + pendingSplits.add(curConnectorSplit); + return Optional::empty; + } + + @Override + public void noMoreSplits() + { + noMoreSplits = true; + } + + @Override + public Page getOutput() + { + if (processorFinishedSplit) { + // start processing a new split + if (pendingSplits.isEmpty()) { + // no more splits to process at the moment + return null; + } + currentSplit = pendingSplits.remove(0); + resetProcessor(); + } + else { + // a split is being processed + requireNonNull(currentSplit, "currentSplit is null"); + } + + TableFunctionProcessorState state = processor.process(processorUsedData ? null : currentSplit); + if (state == FINISHED) { + processorFinishedSplit = true; + } + if (state instanceof Blocked) { + Blocked blocked = (Blocked) state; + processorBlocked = toListenableFuture(blocked.getFuture()); + } + if (state instanceof Processed) { + Processed processed = (Processed) state; + if (processed.isUsedInput()) { + processorUsedData = true; + } + if (processed.getResult() != null) { + return processed.getResult(); + } + } + return null; + } + + @Override + public ListenableFuture isBlocked() + { + return processorBlocked; + } + + @Override + public void finish() + { + // this method is redundant. the operator takes no input at all. noMoreSplits() should be called instead. + } + + @Override + public boolean isFinished() + { + return processorFinishedSplit && pendingSplits.isEmpty() && noMoreSplits; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PageBuffer.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PageBuffer.java new file mode 100644 index 0000000000000..cb14500597a1b --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PageBuffer.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.operator.WorkProcessor.ProcessState; +import jakarta.annotation.Nullable; + +import static com.facebook.presto.operator.WorkProcessor.ProcessState.finished; +import static com.facebook.presto.operator.WorkProcessor.ProcessState.ofResult; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class PageBuffer +{ + @Nullable + private Page page; + private boolean finished; + + public WorkProcessor pages() + { + return WorkProcessor.create(() -> { + if (isFinished() && isEmpty()) { + return finished(); + } + + if (!isEmpty()) { + Page result = page; + page = null; + return ofResult(result); + } + + return ProcessState.yield(); + }); + } + + public boolean isEmpty() + { + return page == null; + } + + public boolean isFinished() + { + return finished; + } + + public void add(Page page) + { + checkState(isEmpty(), "page buffer is not empty"); + checkState(!isFinished(), "page buffer is finished"); + requireNonNull(page, "page is null"); + + if (page.getPositionCount() == 0) { + return; + } + + this.page = page; + } + + public void finish() + { + finished = true; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java index 640ae9919ca90..f4da55ac314e0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java @@ -270,9 +270,9 @@ public void swap(int a, int b) valueAddresses.swap(a, b); } - public int buildPage(int position, int[] outputChannels, PageBuilder pageBuilder) + public int buildPage(int position, int endPosition, int[] outputChannels, PageBuilder pageBuilder) { - while (!pageBuilder.isFull() && position < positionCount) { + while (!pageBuilder.isFull() && position < endPosition) { long pageAddress = valueAddresses.get(position); int blockIndex = decodeSliceIndex(pageAddress); int blockPosition = decodePosition(pageAddress); @@ -562,10 +562,29 @@ protected Page computeNext() } public Iterator getSortedPages() + { + return getSortedPagesFromRange(0, positionCount); + } + + /** + * Get sorted pages from the specified section of the PagesIndex. + * + * @param start start position of the section, inclusive + * @param end end position of the section, exclusive + * @return iterator of pages + */ + public Iterator getSortedPages(int start, int end) + { + checkArgument(start >= 0 && end <= positionCount, "position range out of bounds"); + checkArgument(start <= end, "invalid position range"); + return getSortedPagesFromRange(start, end); + } + + private Iterator getSortedPagesFromRange(int start, int end) { return new AbstractIterator() { - private int currentPosition; + private int currentPosition = start; private final PageBuilder pageBuilder = new PageBuilder(types); private final int[] outputChannels = new int[types.size()]; @@ -576,7 +595,7 @@ public Iterator getSortedPages() @Override public Page computeNext() { - currentPosition = buildPage(currentPosition, outputChannels, pageBuilder); + currentPosition = buildPage(currentPosition, end, outputChannels, pageBuilder); if (pageBuilder.isEmpty()) { return endOfData(); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/RegularTableFunctionPartition.java b/presto-main-base/src/main/java/com/facebook/presto/operator/RegularTableFunctionPartition.java new file mode 100644 index 0000000000000..5d0376f3be7ee --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/RegularTableFunctionPartition.java @@ -0,0 +1,438 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.RunLengthEncodedBlock; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.primitives.Ints; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; +import static com.facebook.presto.common.Utils.checkState; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.util.concurrent.Futures.immediateFuture; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class RegularTableFunctionPartition + implements TableFunctionPartition +{ + private final PagesIndex pagesIndex; + private final int partitionStart; + private final int partitionEnd; + private final Iterator sortedPages; + + private final TableFunctionDataProcessor tableFunction; + private final int properChannelsCount; + private final int passThroughSourcesCount; + + // channels required by the table function, listed by source in order of argument declarations + private final int[][] requiredChannels; + + // for each input channel, the end position of actual data in that channel (exclusive) relative to partition. The remaining rows are "filler" rows, and should not be passed to table function or passed-through + private final int[] endOfData; + + // a builder for each pass-through column, in order of argument declarations + private final PassThroughColumnProvider[] passThroughProviders; + + // number of processed input positions from partition start. all sources have been processed up to this position, except the sources whose partitions ended earlier. + private int processedPositions; + + public RegularTableFunctionPartition( + PagesIndex pagesIndex, + int partitionStart, + int partitionEnd, + TableFunctionDataProcessor tableFunction, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications) + + { + checkArgument(pagesIndex.getPositionCount() != 0, "PagesIndex is empty for regular table function partition"); + this.pagesIndex = pagesIndex; + this.partitionStart = partitionStart; + this.partitionEnd = partitionEnd; + this.sortedPages = pagesIndex.getSortedPages(partitionStart, partitionEnd); + this.tableFunction = requireNonNull(tableFunction, "tableFunction is null"); + this.properChannelsCount = properChannelsCount; + this.passThroughSourcesCount = passThroughSourcesCount; + this.requiredChannels = requiredChannels.stream() + .map(Ints::toArray) + .toArray(int[][]::new); + this.endOfData = findEndOfData(markerChannels, requiredChannels, passThroughSpecifications); + for (List channels : requiredChannels) { + checkState( + channels.stream() + .mapToInt(channel -> endOfData[channel]) + .distinct() + .count() <= 1, + "end-of-data position is inconsistent within a table function source"); + } + this.passThroughProviders = new PassThroughColumnProvider[passThroughSpecifications.size()]; + for (int i = 0; i < passThroughSpecifications.size(); i++) { + passThroughProviders[i] = createColumnProvider(passThroughSpecifications.get(i)); + } + } + + @Override + public WorkProcessor toOutputPages() + { + return WorkProcessor.create(new WorkProcessor.Process() + { + List> inputPages = prepareInputPages(); + + @Override + public WorkProcessor.ProcessState process() + { + TableFunctionProcessorState state = tableFunction.process(inputPages); + boolean functionGotNoData = inputPages == null; + if (state == FINISHED) { + return WorkProcessor.ProcessState.finished(); + } + if (state instanceof TableFunctionProcessorState.Blocked) { + return WorkProcessor.ProcessState.blocked(toListenableFuture(((TableFunctionProcessorState.Blocked) state).getFuture())); + } + TableFunctionProcessorState.Processed processed = (TableFunctionProcessorState.Processed) state; + if (processed.isUsedInput()) { + inputPages = prepareInputPages(); + } + if (processed.getResult() != null) { + return WorkProcessor.ProcessState.ofResult(appendPassThroughColumns(processed.getResult())); + } + if (functionGotNoData) { + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, "When function got no input, it should either produce output or return Blocked state"); + } + return WorkProcessor.ProcessState.blocked(immediateFuture(null)); + } + }); + } + + /** + * Iterate over the partition by page and extract pages for each table function source from the input page. + * For each source, project the columns required by the table function. + * If for some source all data in the partition has been consumed, Optional.empty() is returned for that source. + * It happens when the partition of this source is shorter than the partition of some other source. + * The overall length of the table function partition is equal to the length of the longest source partition. + * When all sources are fully consumed, this method returns null. + *

+ * NOTE: There are two types of table function's source semantics: set and row. The two types of sources should be handled + * by the TableFunctionDataProcessor in different ways. For a source with set semantics, the whole partition can be used for computations, + * while for a source with row semantics, each row should be processed independently from all other rows. + * To enforce that behavior, we could pass to the TableFunctionDataProcessor only one row from a table with row semantics. + * However, for performance reasons, we handle sources with row and set semantics in the same way: the TableFunctionDataProcessor + * gets a page of data from each source. The TableFunctionDataProcessor is responsible for using the provided data accordingly + * to the declared source semantics (set or rows). + * + * @return A List containing: + * - Optional Page for every source that is not fully consumed + * - Optional.empty() for every source that is fully consumed + * or null if all sources are fully consumed. + */ + private List> prepareInputPages() + { + if (!sortedPages.hasNext()) { + return null; + } + + Page inputPage = sortedPages.next(); + ImmutableList.Builder> sourcePages = ImmutableList.builder(); + + for (int[] channelsForSource : requiredChannels) { + if (channelsForSource.length == 0) { + sourcePages.add(Optional.of(new Page(inputPage.getPositionCount()))); + } + else { + int endOfDataForSource = endOfData[channelsForSource[0]]; // end-of-data position is validated to be consistent for all channels from source + if (endOfDataForSource <= processedPositions) { + // all data for this source was already processed + sourcePages.add(Optional.empty()); + } + else { + Block[] sourceBlocks = new Block[channelsForSource.length]; + if (endOfDataForSource < processedPositions + inputPage.getPositionCount()) { + // data for this source ends within the current page + for (int i = 0; i < channelsForSource.length; i++) { + int inputChannel = channelsForSource[i]; + sourceBlocks[i] = inputPage.getBlock(inputChannel).getRegion(0, endOfDataForSource - processedPositions); + } + } + else { + // data for this source does not end within the current page + for (int i = 0; i < channelsForSource.length; i++) { + int inputChannel = channelsForSource[i]; + sourceBlocks[i] = inputPage.getBlock(inputChannel); + } + } + sourcePages.add(Optional.of(new Page(sourceBlocks))); + } + } + } + + processedPositions += inputPage.getPositionCount(); + + return sourcePages.build(); + } + + /** + * There are two types of table function's source semantics: set and row. + *

+ * For a source with row semantics, the table function result depends on the whole partition, + * so it is not always possible to associate an output row with a specific input row. + * The TableFunctionDataProcessor can return null as the pass-through index to indicate that + * the output row is not associated with any row from the given source. + *

+ * For a source with row semantics, the output is determined on a row-by-row basis, so every + * output row is associated with a specific input row. In such case, the pass-through index + * should never be null. + *

+ * In our implementation, we handle sources with row and set semantics in the same way. + * For performance reasons, we do not validate the null pass-through indexes. + * The TableFunctionDataProcessor is responsible for using the pass-through capability + * accordingly to the declared source semantics (set or rows). + */ + private Page appendPassThroughColumns(Page page) + { + if (page.getChannelCount() != properChannelsCount + passThroughSourcesCount) { + throw new PrestoException( + FUNCTION_IMPLEMENTATION_ERROR, + format( + "Table function returned a page containing %s channels. Expected channel number: %s (%s proper columns, %s pass-through index columns)", + page.getChannelCount(), + properChannelsCount + passThroughSourcesCount, + properChannelsCount, + passThroughSourcesCount)); + } + // TODO is it possible to verify types of columns returned by TF? + + Block[] resultBlocks = new Block[properChannelsCount + passThroughProviders.length]; + + // proper outputs first + for (int channel = 0; channel < properChannelsCount; channel++) { + resultBlocks[channel] = page.getBlock(channel); + } + + // pass-through columns next + int channel = properChannelsCount; + for (PassThroughColumnProvider provider : passThroughProviders) { + resultBlocks[channel] = provider.getPassThroughColumn(page); + channel++; + } + + // pass the position count so that the Page can be successfully created in the case when there are no output channels (resultBlocks is empty) + return new Page(page.getPositionCount(), resultBlocks); + } + + private int[] findEndOfData(Optional> markerChannels, List> requiredChannels, List passThroughSpecifications) + { + Set referencedChannels = ImmutableSet.builder() + .addAll(requiredChannels.stream() + .flatMap(Collection::stream) + .collect(toImmutableList())) + .addAll(passThroughSpecifications.stream() + .map(PassThroughColumnSpecification::getInputChannel) + .collect(toImmutableList())) + .build(); + + if (referencedChannels.isEmpty()) { + // no required or pass-through channels + return null; + } + + int maxInputChannel = referencedChannels.stream() + .mapToInt(Integer::intValue) + .max() + .orElseThrow(NoSuchElementException::new); + + int[] result = new int[maxInputChannel + 1]; + Arrays.fill(result, -1); + + // if table function had one source, adding a marker channel was not necessary. + // end-of-data position is equal to partition end for each input channel + if (!markerChannels.isPresent()) { + referencedChannels.stream() + .forEach(channel -> result[channel] = partitionEnd - partitionStart); + return result; + } + + // if table function had more than one source, the markers map shall be present, and it shall contain mapping for each input channel + ImmutableMap.Builder endOfDataPerMarkerBuilder = ImmutableMap.builder(); + for (int markerChannel : ImmutableSet.copyOf(markerChannels.orElseThrow(NoSuchElementException::new).values())) { + endOfDataPerMarkerBuilder.put(markerChannel, findFirstNullPosition(markerChannel)); + } + Map endOfDataPerMarker = endOfDataPerMarkerBuilder.buildOrThrow(); + referencedChannels.stream() + .forEach(channel -> result[channel] = endOfDataPerMarker.get(markerChannels.orElseThrow(NoSuchElementException::new).get(channel)) - partitionStart); + + return result; + } + + private int findFirstNullPosition(int markerChannel) + { + if (pagesIndex.isNull(markerChannel, partitionStart)) { + return partitionStart; + } + if (!pagesIndex.isNull(markerChannel, partitionEnd - 1)) { + return partitionEnd; + } + + int start = partitionStart; + int end = partitionEnd; + // value at start is not null, value at end is null + while (end - start > 1) { + int mid = (start + end) >>> 1; + if (pagesIndex.isNull(markerChannel, mid)) { + end = mid; + } + else { + start = mid; + } + } + return end; + } + + public static class PassThroughColumnSpecification + { + private final boolean isPartitioningColumn; + private final int inputChannel; + private final int indexChannel; + + public PassThroughColumnSpecification(boolean isPartitioningColumn, int inputChannel, int indexChannel) + { + this.isPartitioningColumn = isPartitioningColumn; + this.inputChannel = inputChannel; + this.indexChannel = indexChannel; + } + + public boolean isPartitioningColumn() + { + return isPartitioningColumn; + } + + public int getInputChannel() + { + return inputChannel; + } + + public int getIndexChannel() + { + return indexChannel; + } + } + + private PassThroughColumnProvider createColumnProvider(PassThroughColumnSpecification specification) + { + if (specification.isPartitioningColumn()) { + return new PartitioningColumnProvider(pagesIndex.getSingleValueBlock(specification.getInputChannel(), partitionStart)); + } + return new NonPartitioningColumnProvider(specification.getInputChannel(), specification.getIndexChannel()); + } + + private interface PassThroughColumnProvider + { + Block getPassThroughColumn(Page page); + } + + private static class PartitioningColumnProvider + implements PassThroughColumnProvider + { + private final Block partitioningValue; + + private PartitioningColumnProvider(Block partitioningValue) + { + this.partitioningValue = requireNonNull(partitioningValue, "partitioningValue is null"); + } + + @Override + public Block getPassThroughColumn(Page page) + { + return new RunLengthEncodedBlock(partitioningValue, page.getPositionCount()); + } + + public Block getPartitioningValue() + { + return partitioningValue; + } + } + + private final class NonPartitioningColumnProvider + implements PassThroughColumnProvider + { + private final int inputChannel; + private final Type type; + private final int indexChannel; + + public NonPartitioningColumnProvider(int inputChannel, int indexChannel) + { + this.inputChannel = inputChannel; + this.type = pagesIndex.getType(inputChannel); + this.indexChannel = indexChannel; + } + + @Override + public Block getPassThroughColumn(Page page) + { + Block indexes = page.getBlock(indexChannel); + BlockBuilder builder = type.createBlockBuilder(null, page.getPositionCount()); + for (int position = 0; position < page.getPositionCount(); position++) { + if (indexes.isNull(position)) { + builder.appendNull(); + } + else { + // table function returns index from partition start + long index = BIGINT.getLong(indexes, position); + // validate index + if (index < 0 || index >= endOfData[inputChannel] || index >= processedPositions) { + int end = min(endOfData[inputChannel], processedPositions) - 1; + if (end >= 0) { + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, format("Index of a pass-through row: %s out of processed portion of partition [0, %s]", index, end)); + } + else { + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, "Index of a pass-through row must be null when no input data from the partition was processed. Actual: " + index); + } + } + // index in PagesIndex + long absoluteIndex = partitionStart + index; + pagesIndex.appendTo(inputChannel, toIntExact(absoluteIndex), builder); + } + } + + return builder.build(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionOperator.java new file mode 100644 index 0000000000000..97d71899d5b9a --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionOperator.java @@ -0,0 +1,642 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.memory.context.LocalMemoryContext; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.ListenableFuture; +import jakarta.annotation.Nullable; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkPositionIndex; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.concat; +import static java.util.Collections.nCopies; +import static java.util.Objects.requireNonNull; + +public class TableFunctionOperator + implements Operator +{ + public static class TableFunctionOperatorFactory + implements OperatorFactory + { + private final int operatorId; + private final PlanNodeId planNodeId; + + // a provider of table function processor to be called once per partition + private final TableFunctionProcessorProvider tableFunctionProvider; + + // all information necessary to execute the table function collected during analysis + private final ConnectorTableFunctionHandle functionHandle; + + // number of proper columns produced by the table function + private final int properChannelsCount; + + // number of input tables declared as pass-through + private final int passThroughSourcesCount; + + // columns required by the table function, in order of input tables + private final List> requiredChannels; + + // map from input channel to marker channel + // for each input table, there is a channel that marks which rows contain original data, and which are "filler" rows. + // the "filler" rows are part of the algorithm, and they should not be processed by the table function, or passed-through. + // In this map, every original column from the input table is associated with the appropriate marker. + private final Optional> markerChannels; + + // necessary information to build a pass-through column, for all pass-through columns, ordered as expected on the output + // it includes columns from sources declared as pass-through as well as partitioning columns from other sources + private final List passThroughSpecifications; + + // specifies whether the function should be pruned or executed when the input is empty + // pruneWhenEmpty is false if and only if all original input tables are KEEP WHEN EMPTY + private final boolean pruneWhenEmpty; + + // partitioning channels from all sources + private final List partitionChannels; + + // subset of partition channels that are already grouped + private final List prePartitionedChannels; + + // channels necessary to sort all sources: + // - for a single source, these are the source's sort channels + // - for multiple sources, this is a single synthesized row number channel + private final List sortChannels; + private final List sortOrders; + + // number of leading sort channels that are already sorted + private final int preSortedPrefix; + + private final List sourceTypes; + private final int expectedPositions; + private final PagesIndex.Factory pagesIndexFactory; + + private boolean closed; + + public TableFunctionOperatorFactory( + int operatorId, + PlanNodeId planNodeId, + TableFunctionProcessorProvider tableFunctionProvider, + ConnectorTableFunctionHandle functionHandle, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications, + boolean pruneWhenEmpty, + List partitionChannels, + List prePartitionedChannels, + List sortChannels, + List sortOrders, + int preSortedPrefix, + List sourceTypes, + int expectedPositions, + PagesIndex.Factory pagesIndexFactory) + { + requireNonNull(planNodeId, "planNodeId is null"); + requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + requireNonNull(functionHandle, "functionHandle is null"); + requireNonNull(requiredChannels, "requiredChannels is null"); + requireNonNull(markerChannels, "markerChannels is null"); + requireNonNull(passThroughSpecifications, "passThroughSpecifications is null"); + requireNonNull(partitionChannels, "partitionChannels is null"); + requireNonNull(prePartitionedChannels, "prePartitionedChannels is null"); + checkArgument(partitionChannels.containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels"); + requireNonNull(sortChannels, "sortChannels is null"); + requireNonNull(sortOrders, "sortOrders is null"); + checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders"); + checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels"); + checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped"); + requireNonNull(sourceTypes, "sourceTypes is null"); + requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); + + this.operatorId = operatorId; + this.planNodeId = planNodeId; + this.tableFunctionProvider = tableFunctionProvider; + this.functionHandle = functionHandle; + this.properChannelsCount = properChannelsCount; + this.passThroughSourcesCount = passThroughSourcesCount; + this.requiredChannels = requiredChannels.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerChannels = markerChannels.map(ImmutableMap::copyOf); + this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); + this.pruneWhenEmpty = pruneWhenEmpty; + this.partitionChannels = ImmutableList.copyOf(partitionChannels); + this.prePartitionedChannels = ImmutableList.copyOf(prePartitionedChannels); + this.sortChannels = ImmutableList.copyOf(sortChannels); + this.sortOrders = ImmutableList.copyOf(sortOrders); + this.preSortedPrefix = preSortedPrefix; + this.sourceTypes = ImmutableList.copyOf(sourceTypes); + this.expectedPositions = expectedPositions; + this.pagesIndexFactory = pagesIndexFactory; + } + + @Override + public Operator createOperator(DriverContext driverContext) + { + checkState(!closed, "Factory is already closed"); + + OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, TableFunctionOperator.class.getSimpleName()); + return new TableFunctionOperator( + operatorContext, + tableFunctionProvider, + functionHandle, + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications, + pruneWhenEmpty, + partitionChannels, + prePartitionedChannels, + sortChannels, + sortOrders, + preSortedPrefix, + sourceTypes, + expectedPositions, + pagesIndexFactory); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + + @Override + public OperatorFactory duplicate() + { + return new TableFunctionOperatorFactory( + operatorId, + planNodeId, + tableFunctionProvider, + functionHandle, + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications, + pruneWhenEmpty, + partitionChannels, + prePartitionedChannels, + sortChannels, + sortOrders, + preSortedPrefix, + sourceTypes, + expectedPositions, + pagesIndexFactory); + } + } + + private final OperatorContext operatorContext; + + private final PageBuffer pageBuffer = new PageBuffer(); + private final WorkProcessor outputPages; + private final boolean processEmptyInput; + + @Nullable + private Page pendingInput; + private boolean operatorFinishing; + + public TableFunctionOperator( + OperatorContext operatorContext, + TableFunctionProcessorProvider tableFunctionProvider, + ConnectorTableFunctionHandle functionHandle, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications, + boolean pruneWhenEmpty, + List partitionChannels, + List prePartitionedChannels, + List sortChannels, + List sortOrders, + int preSortedPrefix, + List sourceTypes, + int expectedPositions, + PagesIndex.Factory pagesIndexFactory) + { + requireNonNull(operatorContext, "operatorContext is null"); + requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + requireNonNull(functionHandle, "functionHandle is null"); + requireNonNull(requiredChannels, "requiredChannels is null"); + requireNonNull(markerChannels, "markerChannels is null"); + requireNonNull(passThroughSpecifications, "passThroughSpecifications is null"); + requireNonNull(partitionChannels, "partitionChannels is null"); + requireNonNull(prePartitionedChannels, "prePartitionedChannels is null"); + checkArgument(partitionChannels.containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels"); + requireNonNull(sortChannels, "sortChannels is null"); + requireNonNull(sortOrders, "sortOrders is null"); + checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders"); + checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels"); + checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped"); + requireNonNull(sourceTypes, "sourceTypes is null"); + requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); + + this.operatorContext = operatorContext; + + this.processEmptyInput = !pruneWhenEmpty; + + PagesIndex pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions); + HashStrategies hashStrategies = new HashStrategies(pagesIndex, partitionChannels, prePartitionedChannels, sortChannels, sortOrders, preSortedPrefix); + + this.outputPages = WorkProcessor.create(new PagesSource()) + .transform(new PartitionAndSort(pagesIndex, hashStrategies, processEmptyInput)) + .flatMap(groupPagesIndex -> pagesIndexToTableFunctionPartitions( + groupPagesIndex, + hashStrategies, + tableFunctionProvider, + functionHandle, + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications, + processEmptyInput)) + .flatMap(TableFunctionPartition::toOutputPages); + } + + @Override + public OperatorContext getOperatorContext() + { + return operatorContext; + } + + @Override + public void finish() + { + operatorFinishing = true; + } + + @Override + public boolean isFinished() + { + return outputPages.isFinished(); + } + + @Override + public ListenableFuture isBlocked() + { + if (outputPages.isBlocked()) { + return outputPages.getBlockedFuture(); + } + + return NOT_BLOCKED; + } + + @Override + public boolean needsInput() + { + return pendingInput == null && !operatorFinishing; + } + + @Override + public void addInput(Page page) + { + requireNonNull(page, "page is null"); + checkState(pendingInput == null, "Operator already has pending input"); + + if (page.getPositionCount() == 0) { + return; + } + + pendingInput = page; + } + + @Override + public Page getOutput() + { + if (!outputPages.process()) { + return null; + } + + if (outputPages.isFinished()) { + return null; + } + + return outputPages.getResult(); + } + + private static class HashStrategies + { + final PagesHashStrategy prePartitionedStrategy; + final PagesHashStrategy remainingPartitionStrategy; + final PagesHashStrategy preSortedStrategy; + final List remainingPartitionAndSortChannels; + final List remainingSortOrders; + final int[] prePartitionedChannelsArray; + + public HashStrategies( + PagesIndex pagesIndex, + List partitionChannels, + List prePartitionedChannels, + List sortChannels, + List sortOrders, + int preSortedPrefix) + { + this.prePartitionedStrategy = pagesIndex.createPagesHashStrategy(prePartitionedChannels, OptionalInt.empty()); + + List remainingPartitionChannels = partitionChannels.stream() + .filter(channel -> !prePartitionedChannels.contains(channel)) + .collect(toImmutableList()); + this.remainingPartitionStrategy = pagesIndex.createPagesHashStrategy(remainingPartitionChannels, OptionalInt.empty()); + + List preSortedChannels = sortChannels.stream() + .limit(preSortedPrefix) + .collect(toImmutableList()); + this.preSortedStrategy = pagesIndex.createPagesHashStrategy(preSortedChannels, OptionalInt.empty()); + + if (preSortedPrefix > 0) { + // preSortedPrefix > 0 implies that all partition channels are already pre-partitioned (enforced by check in the constructor), so we only need to do the remaining sort + this.remainingPartitionAndSortChannels = ImmutableList.copyOf(Iterables.skip(sortChannels, preSortedPrefix)); + this.remainingSortOrders = ImmutableList.copyOf(Iterables.skip(sortOrders, preSortedPrefix)); + } + else { + // we need to sort by the remaining partition channels so that the input is fully partitioned, + // and then need to we sort by all the sort channels so that the input is fully sorted + this.remainingPartitionAndSortChannels = ImmutableList.copyOf(concat(remainingPartitionChannels, sortChannels)); + this.remainingSortOrders = ImmutableList.copyOf(concat(nCopies(remainingPartitionChannels.size(), ASC_NULLS_LAST), sortOrders)); + } + + this.prePartitionedChannelsArray = Ints.toArray(prePartitionedChannels); + } + } + + private class PartitionAndSort + implements WorkProcessor.Transformation + { + private final PagesIndex pagesIndex; + private final HashStrategies hashStrategies; + private final LocalMemoryContext memoryContext; + + private boolean resetPagesIndex; + private int inputPosition; + private boolean processEmptyInput; + + public PartitionAndSort(PagesIndex pagesIndex, HashStrategies hashStrategies, boolean processEmptyInput) + { + this.pagesIndex = pagesIndex; + this.hashStrategies = hashStrategies; + this.memoryContext = operatorContext.aggregateUserMemoryContext().newLocalMemoryContext(PartitionAndSort.class.getSimpleName()); + this.processEmptyInput = processEmptyInput; + } + + @Override + public WorkProcessor.TransformationState process(Optional input) + { + if (resetPagesIndex) { + pagesIndex.clear(); + updateMemoryUsage(); + resetPagesIndex = false; + } + + if (!input.isPresent() && pagesIndex.getPositionCount() == 0) { + if (processEmptyInput) { + // it can only happen at the first call to process(), which implies that there is no input. Empty PagesIndex can be passed on only once. + processEmptyInput = false; + return WorkProcessor.TransformationState.ofResult(pagesIndex, false); + } + else { + memoryContext.close(); + return WorkProcessor.TransformationState.finished(); + } + } + + // there is input, so we are not interested in processing empty input + processEmptyInput = false; + + if (input.isPresent()) { + // append rows from input which belong to the current group wrt pre-partitioned columns + // it might be one or more partitions + inputPosition = appendCurrentGroup(pagesIndex, hashStrategies, input.get(), inputPosition); + updateMemoryUsage(); + + if (inputPosition >= input.get().getPositionCount()) { + inputPosition = 0; + return WorkProcessor.TransformationState.needsMoreData(); + } + } + + // we have unused input or the input is finished. we have buffered a full group + // the group contains one or more partitions, as it was determined by the pre-partitioned columns + // sorting serves two purposes: + // - sort by the remaining partition channels so that the input is fully partitioned, + // - sort by all the sort channels so that the input is fully sorted + sortCurrentGroup(pagesIndex, hashStrategies); + resetPagesIndex = true; + return WorkProcessor.TransformationState.ofResult(pagesIndex, false); + } + + void updateMemoryUsage() + { + memoryContext.setBytes(pagesIndex.getEstimatedSize().toBytes()); + } + } + + private static int appendCurrentGroup(PagesIndex pagesIndex, HashStrategies hashStrategies, Page page, int startPosition) + { + checkArgument(page.getPositionCount() > startPosition); + + PagesHashStrategy prePartitionedStrategy = hashStrategies.prePartitionedStrategy; + Page prePartitionedPage = page.extractChannels(hashStrategies.prePartitionedChannelsArray); + + if (pagesIndex.getPositionCount() == 0 || pagesIndex.positionEqualsRow(prePartitionedStrategy, 0, startPosition, prePartitionedPage)) { + // we are within the current group. find the position where the pre-grouped columns change + int groupEnd = findGroupEnd(prePartitionedPage, prePartitionedStrategy, startPosition); + + // add the section of the page that contains values for the current group + pagesIndex.addPage(page.getRegion(startPosition, groupEnd - startPosition)); + + if (page.getPositionCount() - groupEnd > 0) { + // the remaining prt of the page contains the next group + return groupEnd; + } + // page fully consumed: it contains the current group only + return page.getPositionCount(); + } + + // we had previous results buffered, but the remaining page starts with new group values + return startPosition; + } + + private static void sortCurrentGroup(PagesIndex pagesIndex, HashStrategies hashStrategies) + { + PagesHashStrategy preSortedStrategy = hashStrategies.preSortedStrategy; + List remainingPartitionAndSortChannels = hashStrategies.remainingPartitionAndSortChannels; + List remainingSortOrders = hashStrategies.remainingSortOrders; + + if (pagesIndex.getPositionCount() > 1 && !remainingPartitionAndSortChannels.isEmpty()) { + int startPosition = 0; + while (startPosition < pagesIndex.getPositionCount()) { + int endPosition = findGroupEnd(pagesIndex, preSortedStrategy, startPosition); + pagesIndex.sort(remainingPartitionAndSortChannels, remainingSortOrders, startPosition, endPosition); + startPosition = endPosition; + } + } + } + + // Assumes input grouped on relevant pagesHashStrategy columns + private static int findGroupEnd(Page page, PagesHashStrategy pagesHashStrategy, int startPosition) + { + checkArgument(page.getPositionCount() > 0, "Must have at least one position"); + checkPositionIndex(startPosition, page.getPositionCount(), "startPosition out of bounds"); + + return findEndPosition(startPosition, page.getPositionCount(), (firstPosition, secondPosition) -> pagesHashStrategy.rowEqualsRow(firstPosition, page, secondPosition, page)); + } + + // Assumes input grouped on relevant pagesHashStrategy columns + private static int findGroupEnd(PagesIndex pagesIndex, PagesHashStrategy pagesHashStrategy, int startPosition) + { + checkArgument(pagesIndex.getPositionCount() > 0, "Must have at least one position"); + checkPositionIndex(startPosition, pagesIndex.getPositionCount(), "startPosition out of bounds"); + + return findEndPosition(startPosition, pagesIndex.getPositionCount(), (firstPosition, secondPosition) -> pagesIndex.positionEqualsPosition(pagesHashStrategy, firstPosition, secondPosition)); + } + + /** + * @param startPosition - inclusive + * @param endPosition - exclusive + * @param comparator - returns true if positions given as parameters are equal + * @return the end of the group position exclusive (the position the very next group starts) + */ + @VisibleForTesting + static int findEndPosition(int startPosition, int endPosition, PositionComparator comparator) + { + checkArgument(startPosition >= 0, "startPosition must be greater or equal than zero: %s", startPosition); + checkArgument(startPosition < endPosition, "startPosition (%s) must be less than endPosition (%s)", startPosition, endPosition); + + int left = startPosition; + int right = endPosition; + + while (right - left > 1) { + int middle = (left + right) >>> 1; + + if (comparator.test(startPosition, middle)) { + left = middle; + } + else { + right = middle; + } + } + + return right; + } + + private interface PositionComparator + { + boolean test(int first, int second); + } + + private WorkProcessor pagesIndexToTableFunctionPartitions( + PagesIndex pagesIndex, + HashStrategies hashStrategies, + TableFunctionProcessorProvider tableFunctionProvider, + ConnectorTableFunctionHandle functionHandle, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications, + boolean processEmptyInput) + { + // pagesIndex contains the full grouped and sorted data for one or more partitions + + PagesHashStrategy remainingPartitionStrategy = hashStrategies.remainingPartitionStrategy; + + return WorkProcessor.create(new WorkProcessor.Process() + { + private int partitionStart; + private boolean processEmpty = processEmptyInput; + + @Override + public WorkProcessor.ProcessState process() + { + if (partitionStart == pagesIndex.getPositionCount()) { + if (processEmpty && pagesIndex.getPositionCount() == 0) { + // empty PagesIndex can only be passed once as the result of PartitionAndSort. Neither this nor any future instance of Process will ever get an empty PagesIndex again. + processEmpty = false; + return WorkProcessor.ProcessState.ofResult(new EmptyTableFunctionPartition( + tableFunctionProvider.getDataProcessor(functionHandle), + properChannelsCount, + passThroughSourcesCount, + passThroughSpecifications.stream() + .map(RegularTableFunctionPartition.PassThroughColumnSpecification::getInputChannel) + .map(pagesIndex::getType) + .collect(toImmutableList()))); + } + return WorkProcessor.ProcessState.finished(); + } + + // there is input, so we are not interested in processing empty input + processEmpty = false; + + int partitionEnd = findGroupEnd(pagesIndex, remainingPartitionStrategy, partitionStart); + + RegularTableFunctionPartition partition = new RegularTableFunctionPartition( + pagesIndex, + partitionStart, + partitionEnd, + tableFunctionProvider.getDataProcessor(functionHandle), + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications); + + partitionStart = partitionEnd; + return WorkProcessor.ProcessState.ofResult(partition); + } + }); + } + + private class PagesSource + implements WorkProcessor.Process + { + @Override + public WorkProcessor.ProcessState process() + { + if (operatorFinishing && pendingInput == null) { + return WorkProcessor.ProcessState.finished(); + } + + if (pendingInput != null) { + Page result = pendingInput; + pendingInput = null; + return WorkProcessor.ProcessState.ofResult(result); + } + + return WorkProcessor.ProcessState.yield(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionPartition.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionPartition.java new file mode 100644 index 0000000000000..1876b352bd251 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionPartition.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; + +public interface TableFunctionPartition +{ + WorkProcessor toOutputPages(); +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/project/PageProcessor.java b/presto-main-base/src/main/java/com/facebook/presto/operator/project/PageProcessor.java index 736662d471eaf..744b487e670d4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/project/PageProcessor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/project/PageProcessor.java @@ -43,7 +43,6 @@ import static com.facebook.presto.common.block.DictionaryId.randomDictionaryId; import static com.facebook.presto.operator.WorkProcessor.ProcessState.finished; import static com.facebook.presto.operator.WorkProcessor.ProcessState.ofResult; -import static com.facebook.presto.operator.WorkProcessor.ProcessState.yield; import static com.facebook.presto.operator.project.SelectedPositions.positionsRange; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; @@ -204,7 +203,7 @@ public ProcessState process() lastComputeYielded = true; lastComputeBatchSize = batchSize; updateRetainedSize(); - return yield(); + return ProcessState.yield(); } if (result.isPageTooLarge()) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/table/ExcludeColumns.java b/presto-main-base/src/main/java/com/facebook/presto/operator/table/ExcludeColumns.java new file mode 100644 index 0000000000000..d7881b63c0bc3 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/table/ExcludeColumns.java @@ -0,0 +1,169 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.table; + +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.AbstractConnectorTableFunction; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.DescriptorArgumentSpecification; +import com.facebook.presto.spi.function.table.TableArgument; +import com.facebook.presto.spi.function.table.TableArgumentSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; + +import javax.inject.Provider; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; +import static com.facebook.presto.spi.function.table.GenericTableReturnTypeSpecification.GENERIC_TABLE; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInputAndProduced; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.stream.Collectors.joining; + +public class ExcludeColumns + implements Provider +{ + public static final String NAME = "exclude_columns"; + + @Override + public ConnectorTableFunction get() + { + return new ExcludeColumnsFunction(); + } + + public static class ExcludeColumnsFunction + extends AbstractConnectorTableFunction + { + private static final String TABLE_ARGUMENT_NAME = "INPUT"; + private static final String DESCRIPTOR_ARGUMENT_NAME = "COLUMNS"; + + public ExcludeColumnsFunction() + { + super( + "builtin", + NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name(TABLE_ARGUMENT_NAME) + .rowSemantics() + .build(), + DescriptorArgumentSpecification.builder() + .name(DESCRIPTOR_ARGUMENT_NAME) + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + DescriptorArgument excludedColumns = (DescriptorArgument) arguments.get(DESCRIPTOR_ARGUMENT_NAME); + if (excludedColumns.equals(NULL_DESCRIPTOR)) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "COLUMNS descriptor is null"); + } + Descriptor excludedColumnsDescriptor = excludedColumns.getDescriptor().orElseThrow(() -> new PrestoException(INVALID_ARGUMENTS, "Missing exclude columns descriptor")); + if (excludedColumnsDescriptor.getFields().stream().anyMatch(field -> field.getType().isPresent())) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "COLUMNS descriptor contains types"); + } + + // column names in DescriptorArgument are canonical wrt SQL identifier semantics. + // column names in TableArgument are not canonical wrt SQL identifier semantics, as they are taken from the corresponding RelationType. + // because of that, we match the excluded columns names case-insensitive + // TODO: apply proper identifier semantics + Set excludedNames = excludedColumnsDescriptor.getFields().stream() + .map(Descriptor.Field::getName) + .map(name -> name.orElseThrow(() -> new PrestoException(INVALID_ARGUMENTS, "Missing Descriptor field name")).toLowerCase(ENGLISH)) + .collect(toImmutableSet()); + + List inputSchema = ((TableArgument) arguments.get(TABLE_ARGUMENT_NAME)).getRowType().getFields(); + Set inputNames = inputSchema.stream() + .map(RowType.Field::getName) + .filter(Optional::isPresent) + .map(Optional::get) + .map(name -> name.toLowerCase(ENGLISH)) + .collect(toImmutableSet()); + + if (!inputNames.containsAll(excludedNames)) { + String missingColumns = Sets.difference(excludedNames, inputNames).stream() + .collect(joining(", ", "[", "]")); + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Excluded columns: %s not present in the table", missingColumns)); + } + + ImmutableList.Builder requiredColumns = ImmutableList.builder(); + ImmutableList.Builder returnedColumns = ImmutableList.builder(); + + for (int i = 0; i < inputSchema.size(); i++) { + Optional name = inputSchema.get(i).getName(); + if (!name.isPresent() || !excludedNames.contains(name.orElseThrow(() -> new PrestoException(INVALID_FUNCTION_ARGUMENT, "Missing schema name")).toLowerCase(ENGLISH))) { + requiredColumns.add(i); + // per SQL standard, all columns produced by a table function must be named. We allow anonymous columns. + returnedColumns.add(new Descriptor.Field(name, Optional.of(inputSchema.get(i).getType()))); + } + } + + List returnedType = returnedColumns.build(); + if (returnedType.isEmpty()) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "All columns are excluded"); + } + + return TableFunctionAnalysis.builder() + .requiredColumns(TABLE_ARGUMENT_NAME, requiredColumns.build()) + .returnedType(new Descriptor(returnedType)) + .handle(new ExcludeColumnsFunctionHandle()) + .build(); + } + } + + public static TableFunctionProcessorProvider getExcludeColumnsFunctionProcessorProvider() + { + return new TableFunctionProcessorProvider() + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return input -> { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(getOnlyElement(input).orElseThrow(() -> new PrestoException(INVALID_ARGUMENTS, "Missing data processor input"))); + }; + } + }; + } + + public static class ExcludeColumnsFunctionHandle + implements ConnectorTableFunctionHandle + { + // there's no information to remember. All logic is effectively delegated to the engine via `requiredColumns`. + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/table/Sequence.java b/presto-main-base/src/main/java/com/facebook/presto/operator/table/Sequence.java new file mode 100644 index 0000000000000..d992ff9993316 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/table/Sequence.java @@ -0,0 +1,332 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.table; + +import com.facebook.presto.common.PageBuilder; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.NodeProvider; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.AbstractConnectorTableFunction; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.DescribedTableReturnTypeSpecification; +import com.facebook.presto.spi.function.table.ScalarArgument; +import com.facebook.presto.spi.function.table.ScalarArgumentSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.facebook.presto.spi.function.table.TableFunctionSplitProcessor; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import javax.inject.Provider; + +import java.math.BigInteger; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.operator.table.Sequence.SequenceFunctionSplit.MAX_SPLIT_SIZE; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.function.table.Descriptor.descriptor; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.produced; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInputAndProduced; +import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.String.format; + +public class Sequence + implements Provider +{ + public static final String NAME = "sequence"; + + @Override + public ConnectorTableFunction get() + { + return new SequenceFunction(); + } + + public static class SequenceFunction + extends AbstractConnectorTableFunction + { + private static final String START_ARGUMENT_NAME = "START"; + private static final String STOP_ARGUMENT_NAME = "STOP"; + private static final String STEP_ARGUMENT_NAME = "STEP"; + + public SequenceFunction() + { + super( + "builtin", + NAME, + ImmutableList.of( + ScalarArgumentSpecification.builder() + .name(START_ARGUMENT_NAME) + .type(BIGINT) + .defaultValue(0L) + .build(), + ScalarArgumentSpecification.builder() + .name(STOP_ARGUMENT_NAME) + .type(BIGINT) + .build(), + ScalarArgumentSpecification.builder() + .name(STEP_ARGUMENT_NAME) + .type(BIGINT) + .defaultValue(1L) + .build()), + new DescribedTableReturnTypeSpecification(descriptor(ImmutableList.of("sequential_number"), ImmutableList.of(BIGINT)))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + Object startValue = ((ScalarArgument) arguments.get(START_ARGUMENT_NAME)).getValue(); + if (startValue == null) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Start is null"); + } + + Object stopValue = ((ScalarArgument) arguments.get(STOP_ARGUMENT_NAME)).getValue(); + if (stopValue == null) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Stop is null"); + } + + Object stepValue = ((ScalarArgument) arguments.get(STEP_ARGUMENT_NAME)).getValue(); + if (stepValue == null) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Step is null"); + } + + long start = (long) startValue; + long stop = (long) stopValue; + long step = (long) stepValue; + + if (start < stop && step <= 0) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Step must be positive for sequence [%s, %s]", start, stop)); + } + + if (start > stop && step >= 0) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Step must be negative for sequence [%s, %s]", start, stop)); + } + + return TableFunctionAnalysis.builder() + .handle(new SequenceFunctionHandle(start, stop, start == stop ? 0 : step)) + .build(); + } + } + + public static class SequenceFunctionHandle + implements ConnectorTableFunctionHandle + { + private final long start; + private final long stop; + private final long step; + + @JsonCreator + public SequenceFunctionHandle(@JsonProperty("start") long start, @JsonProperty("stop") long stop, @JsonProperty("step") long step) + { + this.start = start; + this.stop = stop; + this.step = step; + } + + @JsonProperty + public long start() + { + return start; + } + + @JsonProperty + public long stop() + { + return stop; + } + + @JsonProperty + public long step() + { + return step; + } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, ConnectorSession session, NodeManager nodeManager, Object functionAndTypeManager) + { + return getSequenceFunctionSplitSource(this); + } + } + + public static ConnectorSplitSource getSequenceFunctionSplitSource(SequenceFunctionHandle handle) + { + // using BigInteger to avoid long overflow since it's not in the main data processing loop + BigInteger start = BigInteger.valueOf(handle.start()); + BigInteger stop = BigInteger.valueOf(handle.stop()); + BigInteger step = BigInteger.valueOf(handle.step()); + + if (step.equals(BigInteger.ZERO)) { + checkArgument(start.equals(stop), "start is not equal to stop for step = 0"); + return new FixedSplitSource(ImmutableList.of(new SequenceFunctionSplit(start.longValueExact(), stop.longValueExact()))); + } + + ImmutableList.Builder splits = ImmutableList.builder(); + + BigInteger totalSteps = stop.subtract(start).divide(step).add(BigInteger.ONE); + BigInteger totalSplits = totalSteps.divide(BigInteger.valueOf(MAX_SPLIT_SIZE)).add(BigInteger.ONE); + BigInteger[] stepsPerSplit = totalSteps.divideAndRemainder(totalSplits); + BigInteger splitJump = stepsPerSplit[0].subtract(BigInteger.ONE).multiply(step); + + BigInteger splitStart = start; + for (BigInteger i = BigInteger.ZERO; i.compareTo(totalSplits) < 0; i = i.add(BigInteger.ONE)) { + BigInteger splitStop = splitStart.add(splitJump); + // distribute the remaining steps between the initial splits, one step per split + if (i.compareTo(stepsPerSplit[1]) < 0) { + splitStop = splitStop.add(step); + } + splits.add(new SequenceFunctionSplit(splitStart.longValueExact(), splitStop.longValueExact())); + splitStart = splitStop.add(step); + } + + return new FixedSplitSource(splits.build()); + } + + public static class SequenceFunctionSplit + implements ConnectorSplit + { + public static final int DEFAULT_SPLIT_SIZE = 1000000; + public static final int MAX_SPLIT_SIZE = 1000000; + + // the first value of sub-sequence + private final long start; + + // the last value of sub-sequence. this value is aligned so that it belongs to the sequence. + private final long stop; + + @JsonCreator + public SequenceFunctionSplit(@JsonProperty("start") long start, @JsonProperty("stop") long stop) + { + this.start = start; + this.stop = stop; + } + + @JsonProperty + public long getStart() + { + return start; + } + + @JsonProperty + public long getStop() + { + return stop; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return ImmutableMap.builder() + .put("start", start) + .put("stop", stop) + .buildOrThrow(); + } + } + + public static TableFunctionProcessorProvider getSequenceFunctionProcessorProvider() + { + return new TableFunctionProcessorProvider() { + @Override + public TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + return new SequenceFunctionProcessor(((SequenceFunctionHandle) handle).step()); + } + }; + } + + public static class SequenceFunctionProcessor + implements TableFunctionSplitProcessor + { + private final PageBuilder page = new PageBuilder(ImmutableList.of(BIGINT)); + private final long step; + private long start; + private long stop; + private boolean finished; + + public SequenceFunctionProcessor(long step) + { + this.step = step; + } + + @Override + public TableFunctionProcessorState process(ConnectorSplit split) + { + if (split != null) { + SequenceFunctionSplit sequenceSplit = (SequenceFunctionSplit) split; + start = sequenceSplit.getStart(); + stop = sequenceSplit.getStop(); + BlockBuilder block = page.getBlockBuilder(0); + while (start != stop && !page.isFull()) { + page.declarePosition(); + BIGINT.writeLong(block, start); + start += step; + } + if (!page.isFull()) { + page.declarePosition(); + BIGINT.writeLong(block, start); + finished = true; + return usedInputAndProduced(page.build()); + } + return usedInputAndProduced(page.build()); + } + + if (finished) { + return FINISHED; + } + + page.reset(); + BlockBuilder block = page.getBlockBuilder(0); + while (start != stop && !page.isFull()) { + page.declarePosition(); + BIGINT.writeLong(block, start); + start += step; + } + if (!page.isFull()) { + page.declarePosition(); + BIGINT.writeLong(block, start); + finished = true; + return produced(page.build()); + } + return produced(page.build()); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java b/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java index d072762d14dbc..34cb2d7821b13 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java @@ -52,6 +52,7 @@ import com.facebook.presto.spi.tracing.TracerProvider; import com.facebook.presto.spi.ttl.ClusterTtlProviderFactory; import com.facebook.presto.spi.ttl.NodeTtlFetcherFactory; +import com.facebook.presto.spi.tvf.TVFProviderFactory; import com.facebook.presto.spi.type.TypeManagerFactory; import com.facebook.presto.sql.analyzer.AnalyzerProviderManager; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; @@ -340,6 +341,11 @@ public void installCoordinatorPlugin(CoordinatorPlugin plugin) log.info("Registering type manager factory %s", typeManagerFactory.getName()); metadata.getFunctionAndTypeManager().addTypeManagerFactory(typeManagerFactory); } + + for (TVFProviderFactory tvfProviderFactory : plugin.getTVFProviderFactories()) { + log.info("Registering table functions provider factory %s", tvfProviderFactory.getName()); + metadata.getFunctionAndTypeManager().addTVFProviderFactory(tvfProviderFactory); + } } private class MainPluginInstaller diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java b/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java index 8bcd62f6006ef..f7f6afc493848 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.Session; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy; @@ -32,14 +33,14 @@ public class CloseableSplitSourceProvider { private static final Logger log = Logger.get(CloseableSplitSourceProvider.class); - private final SplitSourceProvider delegate; + private final SplitManager delegate; @GuardedBy("this") private List splitSources = new ArrayList<>(); @GuardedBy("this") private boolean closed; - public CloseableSplitSourceProvider(SplitSourceProvider delegate) + public CloseableSplitSourceProvider(SplitManager delegate) { this.delegate = requireNonNull(delegate, "delegate is null"); } @@ -53,6 +54,15 @@ public synchronized SplitSource getSplits(Session session, TableHandle tableHand return splitSource; } + @Override + public synchronized SplitSource getSplits(Session session, TableFunctionHandle tableFunctionHandle) + { + checkState(!closed, "split source provider is closed"); + SplitSource splitSource = delegate.getSplitsForTableFunction(session, tableFunctionHandle); + splitSources.add(splitSource); + return splitSource; + } + @Override public synchronized void close() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java b/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java index 8cd7efc5e2f9b..d334e5a5e6a70 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java @@ -18,6 +18,7 @@ import com.facebook.presto.execution.scheduler.NodeSchedulerConfig; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.metadata.TableLayoutResult; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorSession; @@ -104,4 +105,17 @@ private ConnectorSplitManager getConnectorSplitManager(ConnectorId connectorId) checkArgument(result != null, "No split manager for connector '%s'", connectorId); return result; } + + public SplitSource getSplitsForTableFunction(Session session, TableFunctionHandle function) + { + ConnectorId connectorId = function.getConnectorId(); + ConnectorSplitManager splitManager = splitManagers.get(connectorId); + + ConnectorSplitSource source = splitManager.getSplits( + function.getTransactionHandle(), + session.toConnectorSession(connectorId), + function.getFunctionHandle()); + + return new ConnectorAwareSplitSource(connectorId, function.getTransactionHandle(), source); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java b/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java index 617fba7093613..30b54174c27b6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java @@ -14,6 +14,7 @@ package com.facebook.presto.split; import com.facebook.presto.Session; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy; @@ -21,4 +22,5 @@ public interface SplitSourceProvider { SplitSource getSplits(Session session, TableHandle tableHandle, SplitSchedulingStrategy splitSchedulingStrategy, WarningCollector warningCollector); + SplitSource getSplits(Session session, TableFunctionHandle tableFunctionHandle); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index fc383fc928d23..2deb6dc592da2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -36,7 +36,6 @@ import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.OperatorNotFoundException; -import com.facebook.presto.metadata.TableFunctionMetadata; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; @@ -62,6 +61,7 @@ import com.facebook.presto.spi.function.table.Argument; import com.facebook.presto.spi.function.table.ArgumentSpecification; import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.DescribedTableReturnTypeSpecification; import com.facebook.presto.spi.function.table.Descriptor; import com.facebook.presto.spi.function.table.DescriptorArgument; import com.facebook.presto.spi.function.table.DescriptorArgumentSpecification; @@ -71,6 +71,7 @@ import com.facebook.presto.spi.function.table.TableArgument; import com.facebook.presto.spi.function.table.TableArgumentSpecification; import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.facebook.presto.spi.function.table.TableFunctionMetadata; import com.facebook.presto.spi.procedure.DistributedProcedure; import com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure; import com.facebook.presto.spi.relation.DomainTranslator; @@ -265,7 +266,7 @@ import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.function.FunctionKind.WINDOW; import static com.facebook.presto.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; -import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; +import static com.facebook.presto.spi.function.table.GenericTableReturnTypeSpecification.GENERIC_TABLE; import static com.facebook.presto.spi.security.ViewSecurity.DEFINER; import static com.facebook.presto.spi.security.ViewSecurity.INVOKER; import static com.facebook.presto.sql.MaterializedViewUtils.buildOwnerSession; @@ -1507,14 +1508,10 @@ protected Scope visitLateral(Lateral node, Optional scope) @Override protected Scope visitTableFunctionInvocation(TableFunctionInvocation node, Optional scope) { - TableFunctionMetadata tableFunctionMetadata = metadata.getFunctionAndTypeManager() - .getTableFunctionRegistry() - .resolve(session, node.getName()) - .orElseThrow(() -> new SemanticException( - FUNCTION_NOT_FOUND, - node, - "Table function %s not registered", - node.getName())); + TableFunctionMetadata tableFunctionMetadata = metadata.getFunctionAndTypeManager().resolveTableFunction(session, node.getName()); + if (tableFunctionMetadata == null) { + throw new SemanticException(FUNCTION_NOT_FOUND, node, "Table function %s not registered", node.getName()); + } ConnectorTableFunction function = tableFunctionMetadata.getFunction(); ConnectorId connectorId = tableFunctionMetadata.getConnectorId(); @@ -1608,7 +1605,7 @@ private void verifyRequiredColumns(TableFunctionInvocation node, Map column < 0 || column >= inputScope.getRelationType().getAllFieldCount()) // hidden columns can be required as well as visible columns + .filter(column -> column < 0 || column >= inputScope.getRelationType().getVisibleFieldCount()) .findFirst() .ifPresent(column -> { throw new SemanticException(TABLE_FUNCTION_IMPLEMENTATION_ERROR, "Invalid index: %s of required column from table argument %s", column, name); @@ -1626,7 +1623,7 @@ private void verifyRequiredColumns(TableFunctionInvocation node, Map analyzedProperColumnsDescriptor) { switch (returnTypeSpecification.getReturnType()) { - case ReturnTypeSpecification.OnlyPassThrough.returnType: + case "PASSTHROUGH": if (analysis.isAliased(node)) { // According to SQL standard ISO/IEC 9075-2, 7.6 , p. 409, // table alias is prohibited for a table function with ONLY PASS THROUGH returned type. @@ -1648,7 +1645,7 @@ private Descriptor verifyProperColumnsDescriptor(TableFunctionInvocation node, C throw new SemanticException(TABLE_FUNCTION_IMPLEMENTATION_ERROR, "A table function with ONLY_PASS_THROUGH return type must have a table argument with pass-through columns."); } return null; - case ReturnTypeSpecification.GenericTable.returnType: + case "GENERIC": // According to SQL standard ISO/IEC 9075-2, 7.6
, p. 409, // table alias is mandatory for a polymorphic table function invocation which produces proper columns. // We don't enforce this requirement. @@ -1664,7 +1661,7 @@ private Descriptor verifyProperColumnsDescriptor(TableFunctionInvocation node, C // so the function's analyze() method should not return the proper columns descriptor. throw new SemanticException(TABLE_FUNCTION_AMBIGUOUS_RETURN_TYPE, node, "Returned relation type for table function %s is ambiguous", node.getName()); } - return ((ReturnTypeSpecification.DescribedTable) returnTypeSpecification).getDescriptor(); + return ((DescribedTableReturnTypeSpecification) returnTypeSpecification).getDescriptor(); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java index 7683f51b60973..2326c62b56dfd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java @@ -51,6 +51,8 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.sanity.PlanChecker; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -296,6 +298,22 @@ public PlanNode visitValues(ValuesNode node, RewriteContext return context.defaultRewrite(node, context.get()); } + @Override + public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext context) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanNode visitTableFunctionProcessor(TableFunctionProcessorNode node, RewriteContext context) + { + if (!node.getSource().isPresent()) { + // context is mutable. The leaf node should set the PartitioningHandle. + context.get().addSourceDistribution(node.getId(), SOURCE_DISTRIBUTION, metadata, session); + } + return context.defaultRewrite(node, context.get()); + } + @Override public PlanNode visitExchange(ExchangeNode exchange, RewriteContext context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index 878f46dc04fcb..13cc38caef3fe 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -69,6 +69,7 @@ import com.facebook.presto.operator.JoinBridgeManager; import com.facebook.presto.operator.JoinOperatorFactory; import com.facebook.presto.operator.JoinOperatorFactory.OuterOperatorFactoryResult; +import com.facebook.presto.operator.LeafTableFunctionOperator; import com.facebook.presto.operator.LimitOperator.LimitOperatorFactory; import com.facebook.presto.operator.LocalPlannerAware; import com.facebook.presto.operator.LookupJoinOperators; @@ -89,6 +90,7 @@ import com.facebook.presto.operator.PartitionFunction; import com.facebook.presto.operator.PartitionedLookupSourceFactory; import com.facebook.presto.operator.PipelineExecutionStrategy; +import com.facebook.presto.operator.RegularTableFunctionPartition; import com.facebook.presto.operator.RemoteProjectOperator.RemoteProjectOperatorFactory; import com.facebook.presto.operator.RowNumberOperator; import com.facebook.presto.operator.ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory; @@ -102,6 +104,7 @@ import com.facebook.presto.operator.StreamingAggregationOperator.StreamingAggregationOperatorFactory; import com.facebook.presto.operator.TableCommitContext; import com.facebook.presto.operator.TableFinishOperator.PageSinkCommitter; +import com.facebook.presto.operator.TableFunctionOperator; import com.facebook.presto.operator.TableScanOperator.TableScanOperatorFactory; import com.facebook.presto.operator.TableWriterMergeOperator.TableWriterMergeOperatorFactory; import com.facebook.presto.operator.TaskContext; @@ -145,11 +148,13 @@ import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.function.aggregation.LambdaProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.plan.AbstractJoinNode; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.AggregationNode.Step; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -218,6 +223,7 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -250,6 +256,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.OptionalInt; import java.util.Set; @@ -358,9 +365,11 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.DiscreteDomain.integers; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Range.closedOpen; +import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.IntStream.range; @@ -1218,6 +1227,92 @@ public PhysicalOperation visitTableFunction(TableFunctionNode node, LocalExecuti throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); } + @Override + public PhysicalOperation visitTableFunctionProcessor(TableFunctionProcessorNode node, LocalExecutionPlanContext context) + { + TableFunctionProcessorProvider processorProvider = metadata.getFunctionAndTypeManager().getTableFunctionProcessorProvider(node.getHandle()); + + if (!node.getSource().isPresent()) { + OperatorFactory operatorFactory = new LeafTableFunctionOperator.LeafTableFunctionOperatorFactory(context.getNextOperatorId(), node.getId(), processorProvider, node.getHandle().getFunctionHandle()); + return new PhysicalOperation(operatorFactory, makeLayout(node), context, Optional.empty(), UNGROUPED_EXECUTION); + } + + PhysicalOperation source = node.getSource().orElseThrow(NoSuchElementException::new).accept(this, context); + + int properChannelsCount = node.getProperOutputs().size(); + + long passThroughSourcesCount = node.getPassThroughSpecifications().stream() + .filter(TableFunctionNode.PassThroughSpecification::isDeclaredAsPassThrough) + .count(); + + List> requiredChannels = node.getRequiredVariables().stream() + .map(list -> getChannelsForVariables(list, source.getLayout())) + .collect(toImmutableList()); + + Optional> markerChannels = node.getMarkerVariables() + .map(map -> map.entrySet().stream() + .collect(toImmutableMap(entry -> source.getLayout().get(entry.getKey()), entry -> source.getLayout().get(entry.getValue())))); + + int channel = properChannelsCount; + ImmutableList.Builder passThroughColumnSpecifications = ImmutableList.builder(); + for (TableFunctionNode.PassThroughSpecification specification : node.getPassThroughSpecifications()) { + // the table function produces one index channel for each source declared as pass-through. They are laid out after the proper channels. + int indexChannel = specification.isDeclaredAsPassThrough() ? channel++ : -1; + for (TableFunctionNode.PassThroughColumn column : specification.getColumns()) { + passThroughColumnSpecifications.add(new RegularTableFunctionPartition.PassThroughColumnSpecification(column.isPartitioningColumn(), source.getLayout().get(column.getOutputVariables()), indexChannel)); + } + } + + List partitionChannels = node.getSpecification() + .map(DataOrganizationSpecification::getPartitionBy) + .map(list -> getChannelsForVariables(list, source.getLayout())) + .orElse(ImmutableList.of()); + + List sortChannels = ImmutableList.of(); + List sortOrders = ImmutableList.of(); + if (node.getSpecification().flatMap(DataOrganizationSpecification::getOrderingScheme).isPresent()) { + OrderingScheme orderingScheme = node.getSpecification().flatMap(DataOrganizationSpecification::getOrderingScheme).orElseThrow(NoSuchElementException::new); + sortChannels = getChannelsForVariables(orderingScheme.getOrderByVariables(), source.getLayout()); + sortOrders = orderingScheme.getOrderingsMap().values().stream().collect(toImmutableList()); + } + + OperatorFactory operator = new TableFunctionOperator.TableFunctionOperatorFactory( + context.getNextOperatorId(), + node.getId(), + processorProvider, + node.getHandle().getFunctionHandle(), + properChannelsCount, + toIntExact(passThroughSourcesCount), + requiredChannels, + markerChannels, + passThroughColumnSpecifications.build(), + node.isPruneWhenEmpty(), + partitionChannels, + getChannelsForVariables(ImmutableList.copyOf(node.getPrePartitioned()), source.getLayout()), + sortChannels, + sortOrders, + node.getPreSorted(), + source.getTypes(), + 10_000, + pagesIndexFactory); + + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + for (int i = 0; i < node.getProperOutputs().size(); i++) { + outputMappings.put(node.getProperOutputs().get(i), i); + } + List passThroughVariables = node.getPassThroughSpecifications().stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .collect(toImmutableList()); + int outputChannel = properChannelsCount; + for (VariableReferenceExpression passThroughVariable : passThroughVariables) { + outputMappings.put(passThroughVariable, outputChannel++); + } + + return new PhysicalOperation(operator, outputMappings.buildOrThrow(), context, source); + } + @Override public PhysicalOperation visitTopN(TopNNode node, LocalExecutionPlanContext context) { @@ -2942,7 +3037,7 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl Map aggregationMap = aggregation.getAggregations().entrySet() .stream().collect( - ImmutableMap.toImmutableMap( + toImmutableMap( Map.Entry::getKey, entry -> createAggregation(entry.getValue()))); if (groupingVariables.isEmpty()) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index a12b350dcdd1e..d56bdb9b82048 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -84,6 +84,8 @@ import com.facebook.presto.sql.planner.iterative.rule.PruneRedundantProjectionAssignments; import com.facebook.presto.sql.planner.iterative.rule.PruneSemiJoinColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneSemiJoinFilteringSourceColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneTableFunctionProcessorColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneTableFunctionProcessorSourceColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneTableScanColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneTopNColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneUpdateSourceColumns; @@ -123,6 +125,7 @@ import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantLimit; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantSort; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantSortColumns; +import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantTableFunctionProcessor; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantTopN; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantTopNColumns; import com.facebook.presto.sql.planner.iterative.rule.RemoveTrivialFilters; @@ -137,7 +140,6 @@ import com.facebook.presto.sql.planner.iterative.rule.RewriteConstantArrayContainsToInExpression; import com.facebook.presto.sql.planner.iterative.rule.RewriteFilterWithExternalFunctionToProject; import com.facebook.presto.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation; -import com.facebook.presto.sql.planner.iterative.rule.RewriteTableFunctionToTableScan; import com.facebook.presto.sql.planner.iterative.rule.RuntimeReorderJoinSides; import com.facebook.presto.sql.planner.iterative.rule.ScaledWriterRule; import com.facebook.presto.sql.planner.iterative.rule.SimplifyCardinalityMap; @@ -154,6 +156,8 @@ import com.facebook.presto.sql.planner.iterative.rule.TransformDistinctInnerJoinToLeftEarlyOutJoin; import com.facebook.presto.sql.planner.iterative.rule.TransformDistinctInnerJoinToRightEarlyOutJoin; import com.facebook.presto.sql.planner.iterative.rule.TransformExistsApplyToLateralNode; +import com.facebook.presto.sql.planner.iterative.rule.TransformTableFunctionProcessorToTableScan; +import com.facebook.presto.sql.planner.iterative.rule.TransformTableFunctionToTableFunctionProcessor; import com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToDistinctInnerJoin; import com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToSemiJoin; import com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedLateralToJoin; @@ -319,6 +323,8 @@ public PlanOptimizers( new PruneValuesColumns(), new PruneWindowColumns(), new PruneLimitColumns(), + new PruneTableFunctionProcessorColumns(), + new PruneTableFunctionProcessorSourceColumns(), new PruneTableScanColumns()); builder.add(new LogicalCteOptimizer(metadata)); @@ -424,6 +430,7 @@ public PlanOptimizers( .addAll(predicatePushDownRules) .addAll(columnPruningRules) .addAll(ImmutableSet.of( + new TransformTableFunctionToTableFunctionProcessor(metadata), new MergeDuplicateAggregation(metadata.getFunctionAndTypeManager()), new RemoveRedundantIdentityProjections(), new RemoveFullSample(), @@ -446,6 +453,8 @@ public PlanOptimizers( new MergeLimitWithDistinct(), new PruneCountAggregationOverScalar(metadata.getFunctionAndTypeManager()), new PruneOrderByInAggregation(metadata.getFunctionAndTypeManager()), + new RemoveRedundantTableFunctionProcessor(), // must run after TransformTableFunctionToTableFunctionProcessor + new TransformTableFunctionProcessorToTableScan(metadata), // must run after TransformTableFunctionToTableFunctionProcessor new RewriteSpatialPartitioningAggregation(metadata))) .build()), new IterativeOptimizer( @@ -786,7 +795,11 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new RemoveRedundantIdentityProjections(), new PruneRedundantProjectionAssignments())), + ImmutableSet.of( + new RemoveRedundantIdentityProjections(), + new PruneRedundantProjectionAssignments(), + // Re-run RemoveRedundantTableFunctionProcessor after SimplifyPlanWithEmptyInput to optimize empty input tables to empty ValueNode + new RemoveRedundantTableFunctionProcessor())), new PushdownSubfields(metadata, expressionOptimizerManager)); builder.add(predicatePushDown); // Run predicate push down one more time in case we can leverage new information from layouts' effective predicate @@ -881,14 +894,6 @@ public PlanOptimizers( costCalculator, ImmutableSet.of(new ScaledWriterRule()))); - builder.add( - new IterativeOptimizer( - metadata, - ruleStats, - statsCalculator, - costCalculator, - ImmutableSet.of(new RewriteTableFunctionToTableScan(metadata)))); - if (!noExchange) { builder.add(new ReplicateSemiJoinInDelete()); // Must run before AddExchanges diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java index 826eaea10044b..4e32a6e6918ed 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java @@ -205,6 +205,9 @@ public static PlanNode addOverrideProjection(PlanNode source, PlanNodeIdAllocato || source.getOutputVariables().stream().distinct().count() != source.getOutputVariables().size()) { return source; } + if (source instanceof ProjectNode && ((ProjectNode) source).getAssignments().getMap().equals(variableMap)) { + return source; + } Assignments.Builder assignmentsBuilder = Assignments.builder(); assignmentsBuilder.putAll(source.getOutputVariables().stream().collect(toImmutableMap(identity(), x -> variableMap.containsKey(x) ? variableMap.get(x) : x))); return new ProjectNode(source.getSourceLocation(), planNodeIdAllocator.getNextId(), source, assignmentsBuilder.build(), LOCAL); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index f2c102667e06f..c7deb6a31cdec 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -175,7 +175,7 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; -class QueryPlanner +public class QueryPlanner { private final Analysis analysis; private final VariableAllocator variableAllocator; @@ -891,7 +891,7 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression * * @return the new subplan and a mapping of each expression to the symbol representing the coercion or an existing symbol if a coercion wasn't needed */ - private PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata) + public PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata) { Assignments.Builder assignments = Assignments.builder(); assignments.putAll(subPlan.getRoot().getOutputVariables().stream().collect(toImmutableMap(Function.identity(), Function.identity()))); @@ -1713,15 +1713,18 @@ private RowExpression rowExpression(Expression expression, SqlPlannerContext con context.getTranslatorContext()); } - private static List toSymbolReferences(List variables) + public static List toSymbolReferences(List variables) { return variables.stream() - .map(variable -> new SymbolReference( - variable.getSourceLocation().map(location -> new NodeLocation(location.getLine(), location.getColumn())), - variable.getName())) + .map(QueryPlanner::toSymbolReference) .collect(toImmutableList()); } + public static SymbolReference toSymbolReference(VariableReferenceExpression variable) + { + return new SymbolReference(variable.getSourceLocation().map(location -> new NodeLocation(location.getLine(), location.getColumn())), variable.getName()); + } + public static class PlanAndMappings { private final PlanBuilder subPlan; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index be6b7851ec99b..e07fadcbf4666 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.MapType; @@ -31,12 +32,14 @@ import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.CteReferenceNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.IntersectNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.MaterializedViewScanNode; +import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; @@ -52,14 +55,17 @@ import com.facebook.presto.sql.analyzer.Field; import com.facebook.presto.sql.analyzer.RelationId; import com.facebook.presto.sql.analyzer.RelationType; +import com.facebook.presto.sql.analyzer.ResolvedField; import com.facebook.presto.sql.analyzer.Scope; -import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils; import com.facebook.presto.sql.planner.optimizations.SampleNodeUtil; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.CoalesceExpression; @@ -87,11 +93,10 @@ import com.facebook.presto.sql.tree.Row; import com.facebook.presto.sql.tree.SampledRelation; import com.facebook.presto.sql.tree.SetOperation; +import com.facebook.presto.sql.tree.SortItem; import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.Table; -import com.facebook.presto.sql.tree.TableFunctionDescriptorArgument; import com.facebook.presto.sql.tree.TableFunctionInvocation; -import com.facebook.presto.sql.tree.TableFunctionTableArgument; import com.facebook.presto.sql.tree.TableSubquery; import com.facebook.presto.sql.tree.Union; import com.facebook.presto.sql.tree.Unnest; @@ -99,6 +104,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.ListMultimap; import com.google.common.collect.UnmodifiableIterator; @@ -119,6 +125,7 @@ import static com.facebook.presto.SystemSessionProperties.getQueryAnalyzerTimeout; import static com.facebook.presto.common.type.TypeUtils.isEnumType; import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_PLAN_ERROR; import static com.facebook.presto.spi.StandardErrorCode.QUERY_PLANNING_TIMEOUT; import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL; @@ -127,7 +134,6 @@ import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isEqualComparisonExpression; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.resolveEnumLiteral; import static com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy.NONE; -import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException; import static com.facebook.presto.sql.planner.PlannerUtils.newVariable; import static com.facebook.presto.sql.planner.TranslateExpressionsUtil.toRowExpression; @@ -364,51 +370,185 @@ private RelationPlan planMaterializedView(Table node, Analysis.MaterializedViewI return new RelationPlan(mvScanNode, scope, outputVariables); } + /** + * Processes a {@code TableFunctionInvocation} node to construct and return a {@link RelationPlan}. + * This involves preparing the necessary plan nodes, variable mappings, and associated properties + * to represent the execution plan for the invoked table function. + * + * @param node The {@code TableFunctionInvocation} syntax tree node to be processed. + * @param context The SQL planner context used for planning and analysis tasks. + * @return A {@link RelationPlan} encapsulating the execution plan for the table function invocation. + */ @Override protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node, SqlPlannerContext context) { - node.getArguments().stream() - .forEach(argument -> { - if (argument.getValue() instanceof TableFunctionTableArgument) { - throw new SemanticException(NOT_SUPPORTED, argument, "Table arguments are not yet supported for table functions"); - } - if (argument.getValue() instanceof TableFunctionDescriptorArgument) { - throw new SemanticException(NOT_SUPPORTED, argument, "Descriptor arguments are not yet supported for table functions"); - } - }); Analysis.TableFunctionInvocationAnalysis functionAnalysis = analysis.getTableFunctionAnalysis(node); + ImmutableList.Builder sources = ImmutableList.builder(); + ImmutableList.Builder sourceProperties = ImmutableList.builder(); + ImmutableList.Builder outputVariables = ImmutableList.builder(); + + // create new symbols for table function's proper columns + RelationType relationType = analysis.getScope(node).getRelationType(); + List properOutputs = IntStream.range(0, functionAnalysis.getProperColumnsCount()) + .mapToObj(relationType::getFieldByIndex) + .map(field -> variableAllocator.newVariable(getSourceLocation(node), field.getName().orElse("field"), field.getType())) + .collect(toImmutableList()); - // TODO handle input relations: - // 1. extract the input relations from node.getArguments() and plan them. Apply relation coercions if requested. - // 2. for each input relation, prepare the TableArgumentProperties record, consisting of: - // - row or set semantics (from the actualArgument) - // - prune when empty property (from the actualArgument) - // - pass through columns property (from the actualArgument) - // - optional Specification: ordering scheme and partitioning (from the node's argument) <- planned upon the source's RelationPlan (or combined RelationPlan from all sources) - // TODO add - argument name - // TODO add - mapping column name => Symbol // TODO mind the fields without names and duplicate field names in RelationType - List sources = ImmutableList.of(); - List inputRelationsProperties = ImmutableList.of(); - - Scope scope = analysis.getScope(node); + outputVariables.addAll(properOutputs); - ImmutableList.Builder outputVariablesBuilder = ImmutableList.builder(); - for (Field field : scope.getRelationType().getAllFields()) { - VariableReferenceExpression variable = variableAllocator.newVariable(getSourceLocation(node), field.getName().get(), field.getType()); - outputVariablesBuilder.add(variable); - } + processTableArguments(context, functionAnalysis, outputVariables, sources, sourceProperties); - List outputVariables = outputVariablesBuilder.build(); PlanNode root = new TableFunctionNode( idAllocator.getNextId(), functionAnalysis.getFunctionName(), functionAnalysis.getArguments(), - outputVariablesBuilder.build(), - sources.stream().map(RelationPlan::getRoot).collect(toImmutableList()), - inputRelationsProperties, - new TableFunctionHandle(functionAnalysis.getConnectorId(), functionAnalysis.getConnectorTableFunctionHandle(), functionAnalysis.getTransactionHandle())); + properOutputs, + sources.build(), + sourceProperties.build(), + functionAnalysis.getCopartitioningLists(), + new TableFunctionHandle( + functionAnalysis.getConnectorId(), + functionAnalysis.getConnectorTableFunctionHandle(), + functionAnalysis.getTransactionHandle())); + + return new RelationPlan(root, analysis.getScope(node), outputVariables.build()); + } + + private void processTableArguments(SqlPlannerContext context, + Analysis.TableFunctionInvocationAnalysis functionAnalysis, + ImmutableList.Builder outputVariables, + ImmutableList.Builder sources, + ImmutableList.Builder sourceProperties) + { + QueryPlanner partitionQueryPlanner = new QueryPlanner(analysis, variableAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session, context, sqlParser); + // process sources in order of argument declarations + for (Analysis.TableArgumentAnalysis tableArgument : functionAnalysis.getTableArgumentAnalyses()) { + RelationPlan sourcePlan = process(tableArgument.getRelation(), context); + PlanBuilder sourcePlanBuilder = initializePlanBuilder(sourcePlan); + + int[] fieldIndexForVisibleColumn = getFieldIndexesForVisibleColumns(sourcePlan); + + List requiredColumns = functionAnalysis.getRequiredColumns().get(tableArgument.getArgumentName()).stream() + .map(column -> fieldIndexForVisibleColumn[column]) + .map(sourcePlan::getVariable) + .collect(toImmutableList()); + + Optional specification = Optional.empty(); + + // if the table argument has set semantics, create Specification + if (!tableArgument.isRowSemantics()) { + // partition by + List partitionBy = ImmutableList.of(); + // if there are partitioning columns, they might have to be coerced for copartitioning + if (tableArgument.getPartitionBy().isPresent() && !tableArgument.getPartitionBy().get().isEmpty()) { + List partitioningColumns = tableArgument.getPartitionBy().get(); + for (Expression partitionColumn : partitioningColumns) { + if (!sourcePlanBuilder.canTranslate(partitionColumn)) { + ResolvedField partition = sourcePlan.getScope().tryResolveField(partitionColumn).orElseThrow(() -> new PrestoException(INVALID_PLAN_ERROR, "Missing equivalent alias")); + sourcePlanBuilder.getTranslations().put(partitionColumn, sourcePlan.getVariable(partition.getRelationFieldIndex())); + } + } + QueryPlanner.PlanAndMappings copartitionCoercions = partitionQueryPlanner.coerce(sourcePlanBuilder, partitioningColumns, analysis, idAllocator, variableAllocator, metadata); + sourcePlanBuilder = copartitionCoercions.getSubPlan(); + partitionBy = partitioningColumns.stream() + .map(copartitionCoercions::get) + .collect(toImmutableList()); + } - return new RelationPlan(root, scope, outputVariables); + // order by + Optional orderBy = getOrderingScheme(tableArgument, sourcePlanBuilder, sourcePlan); + specification = Optional.of(new DataOrganizationSpecification(partitionBy, orderBy)); + } + + // add output symbols passed from the table argument + ImmutableList.Builder passThroughColumns = ImmutableList.builder(); + addPassthroughColumns(outputVariables, tableArgument, sourcePlan, specification, passThroughColumns, sourcePlanBuilder); + sources.add(sourcePlanBuilder.getRoot()); + + sourceProperties.add(new TableArgumentProperties( + tableArgument.getArgumentName(), + tableArgument.isRowSemantics(), + tableArgument.isPruneWhenEmpty(), + new PassThroughSpecification(tableArgument.isPassThroughColumns(), passThroughColumns.build()), + requiredColumns, + specification)); + } + } + + private static int[] getFieldIndexesForVisibleColumns(RelationPlan sourcePlan) + { + // required columns are a subset of visible columns of the source. remap required column indexes to field indexes in source relation type. + RelationType sourceRelationType = sourcePlan.getScope().getRelationType(); + int[] fieldIndexForVisibleColumn = new int[sourceRelationType.getVisibleFieldCount()]; + int visibleColumn = 0; + for (int i = 0; i < sourceRelationType.getAllFieldCount(); i++) { + if (!sourceRelationType.getFieldByIndex(i).isHidden()) { + fieldIndexForVisibleColumn[visibleColumn] = i; + visibleColumn++; + } + } + return fieldIndexForVisibleColumn; + } + + private static Optional getOrderingScheme(Analysis.TableArgumentAnalysis tableArgument, PlanBuilder sourcePlanBuilder, RelationPlan sourcePlan) + { + Optional orderBy = Optional.empty(); + if (tableArgument.getOrderBy().isPresent()) { + List sortItems = tableArgument.getOrderBy().get().getSortItems(); + + // Ensure all ORDER BY columns can be translated (populate missing translations if needed) + for (SortItem sortItem : sortItems) { + Expression sortKey = sortItem.getSortKey(); + if (!sourcePlanBuilder.canTranslate(sortKey)) { + Optional resolvedField = sourcePlan.getScope().tryResolveField(sortKey); + resolvedField.ifPresent(field -> sourcePlanBuilder.getTranslations().put( + sortKey, + sourcePlan.getVariable(field.getRelationFieldIndex()))); + } + } + + // The ordering symbols are coerced + List coerced = sortItems.stream() + .map(SortItem::getSortKey) + .map(sourcePlanBuilder::translate) + .collect(toImmutableList()); + + List sortOrders = sortItems.stream() + .map(PlannerUtils::toSortOrder) + .collect(toImmutableList()); + + orderBy = Optional.of(PlannerUtils.toOrderingScheme(coerced, sortOrders)); + } + return orderBy; + } + + private static void addPassthroughColumns(ImmutableList.Builder outputVariables, + Analysis.TableArgumentAnalysis tableArgument, RelationPlan sourcePlan, + Optional specification, + ImmutableList.Builder passThroughColumns, + PlanBuilder sourcePlanBuilder) + { + if (tableArgument.isPassThroughColumns()) { + // the original output symbols from the source node, not coerced + // note: hidden columns are included. They are present in sourcePlan.fieldMappings + outputVariables.addAll(sourcePlan.getFieldMappings()); + Set partitionBy = specification + .map(DataOrganizationSpecification::getPartitionBy) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + sourcePlan.getFieldMappings().stream() + .map(variable -> new PassThroughColumn(variable, partitionBy.contains(variable))) + .forEach(passThroughColumns::add); + } + else if (tableArgument.getPartitionBy().isPresent()) { + tableArgument.getPartitionBy().get().stream() + .map(sourcePlanBuilder::translate) + // the original symbols for partitioning columns, not coerced + .forEach(variable -> { + outputVariables.add(variable); + passThroughColumns.add(new PassThroughColumn(variable, true)); + }); + } } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java index abb784cdaa298..471c797c426a8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java @@ -22,9 +22,11 @@ import com.facebook.presto.spi.plan.SpatialJoinNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.NoSuchElementException; import java.util.function.Consumer; public class SchedulingOrderVisitor @@ -88,5 +90,17 @@ public Void visitTableScan(TableScanNode node, Consumer schedulingOr schedulingOrder.accept(node.getId()); return null; } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Consumer schedulingOrder) + { + if (!node.getSource().isPresent()) { + schedulingOrder.accept(node.getId()); + } + else { + node.getSource().orElseThrow(NoSuchElementException::new).accept(this, schedulingOrder); + } + return null; + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java index 28d2bc98b1efb..387271aa94e16 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java @@ -61,6 +61,7 @@ import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -69,6 +70,7 @@ import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.function.Supplier; import static com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.GROUPED_SCHEDULING; @@ -283,6 +285,21 @@ public Map visitRowNumber(RowNumberNode node, Context c return node.getSource().accept(this, context); } + @Override + public Map visitTableFunctionProcessor(TableFunctionProcessorNode node, Context context) + { + if (!node.getSource().isPresent()) { + // this is a source node, so produce splits + SplitSource splitSource = splitSourceProvider.getSplits( + session, + node.getHandle()); + splitSources.add(splitSource); + return ImmutableMap.of(node.getId(), splitSource); + } + + return node.getSource().orElseThrow(NoSuchElementException::new).accept(this, context); + } + @Override public Map visitTopNRowNumber(TopNRowNumberNode node, Context context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorColumns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorColumns.java new file mode 100644 index 0000000000000..c95212bf38c0d --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorColumns.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; + +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; +import static com.google.common.collect.ImmutableList.toImmutableList; + +/** + * TableFunctionProcessorNode has two kinds of outputs: + * - proper outputs, which are the columns produced by the table function, + * - pass-through outputs, which are the columns copied from table arguments. + * This rule filters out unreferenced pass-through symbols. + * Unreferenced proper symbols are not pruned, because there is currently no way + * to communicate to the table function the request for not producing certain columns. + * // TODO prune table function's proper outputs + * Example: + *
+ * - Project
+ *   assignments={proper->proper1}
+ *  - TableFunctionProcessor
+ *    properOutputs=[proper1, proper2]
+ *    passThroughSymbols=[[passthrough1],[passthrough2]]
+ * 
+ * is transformed into + *
+ * - Project
+ *   assignments={proper->proper1}
+ *   - TableFunctionProcessor
+ *     properOutputs=[proper1, proper2]
+ *     passThroughSymbols=[]
+ * 
+ */ +public class PruneTableFunctionProcessorColumns + extends ProjectOffPushDownRule +{ + public PruneTableFunctionProcessorColumns() + { + super(tableFunctionProcessor()); + } + + @Override + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, TableFunctionProcessorNode node, Set referencedOutputs) + { + List prunedPassThroughSpecifications = node.getPassThroughSpecifications().stream() + .map(sourceSpecification -> { + List prunedPassThroughColumns = sourceSpecification.getColumns().stream() + .filter(column -> referencedOutputs.contains(column.getOutputVariables())) + .collect(toImmutableList()); + return new TableFunctionNode.PassThroughSpecification(sourceSpecification.isDeclaredAsPassThrough(), prunedPassThroughColumns); + }) + .collect(toImmutableList()); + + int originalPassThroughCount = node.getPassThroughSpecifications().stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .mapToInt(List::size) + .sum(); + + int prunedPassThroughCount = prunedPassThroughSpecifications.stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .mapToInt(List::size) + .sum(); + + if (originalPassThroughCount == prunedPassThroughCount) { + return Optional.empty(); + } + + return Optional.of(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + node.getSource(), + node.isPruneWhenEmpty(), + prunedPassThroughSpecifications, + node.getRequiredVariables(), + node.getMarkerVariables(), + node.getSpecification(), + node.getPrePartitioned(), + node.getPreSorted(), + node.getHashSymbol(), + node.getHandle())); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorSourceColumns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorSourceColumns.java new file mode 100644 index 0000000000000..d90f668d4c98f --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorSourceColumns.java @@ -0,0 +1,128 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.google.common.collect.ImmutableSet; + +import java.util.Collection; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; +import static com.google.common.collect.Maps.filterKeys; + +/** + * This rule prunes unreferenced outputs of TableFunctionProcessorNode. + * First, it extracts all symbols required for: + * - pass-through + * - table function computation + * - partitioning and ordering (including the hashSymbol) + * Next, a mapping of input symbols to marker symbols is updated + * so that it only contains mappings for the required symbols. + * Last, all the remaining marker symbols are added to the collection + * of required symbols. + * Any source output symbols not included in the required symbols + * can be pruned. + * Example: + *
+ * - TableFunctionProcessor
+ *   properOutputs=[proper]
+ *   passThroughSymbols=[[passthrough1],[passthrough2]]
+ *   requiredSymbols=[[require1], [require2]]
+ *   specification=[partition={[partition1]} orderby={[order1 ASC_NULLS_LAST]}]
+ *   hashSymbol=[hash]
+ *   markerVariables={passthrough1->marker1, require1->marker1, partition1->marker1, order1->marker1, passthrough2->marker2, require2->marker, unreferenced->marker2}
+ *   - Source (which produces passthrough1, require1, partition1, order1, passthrough2, require2, marker, hash, unreferenced)
+ * 
+ * is transformed into + *
+ * - TableFunctionProcessor
+ *   properOutputs=[proper]
+ *   passThroughSymbols=[[passthrough1],[passthrough2]]
+ *   requiredSymbols=[[require1], [require2]]
+ *   specification=[partition={[partition1]} orderby={[order1 ASC_NULLS_LAST]}]
+ *   hashSymbol=[hash]
+ *   markerVariables={passthrough1->marker1, require1->marker1, partition1->marker1, order1->marker1, passthrough2->marker2, require2->marker}
+ *   - Project
+ *     assignments=[passthrough1, require1, partition1, order1, passthrough2, require2, marker, hash]
+ *     - Source (which produces passthrough1, require1, partition1, order1, passthrough2, require2, marker, hash, unreferenced)
+ * 
+ */ +public class PruneTableFunctionProcessorSourceColumns + implements Rule +{ + private static final Pattern PATTERN = tableFunctionProcessor(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) + { + if (!node.getSource().isPresent()) { + return Result.empty(); + } + + ImmutableSet.Builder requiredInputs = ImmutableSet.builder(); + + node.getPassThroughSpecifications().stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .forEach(requiredInputs::add); + + node.getRequiredVariables() + .forEach(requiredInputs::addAll); + + node.getSpecification().ifPresent(specification -> { + requiredInputs.addAll(specification.getPartitionBy()); + specification.getOrderingScheme().ifPresent(orderingScheme -> requiredInputs.addAll(orderingScheme.getOrderByVariables())); + }); + + node.getHashSymbol().ifPresent(requiredInputs::add); + + Optional> updatedMarkerSymbols = node.getMarkerVariables() + .map(mapping -> filterKeys(mapping, requiredInputs.build()::contains)); + + updatedMarkerSymbols.ifPresent(mapping -> requiredInputs.addAll(mapping.values())); + + return restrictOutputs(context.getIdAllocator(), node.getSource().orElseThrow(NoSuchElementException::new), requiredInputs.build()) + .map(child -> Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(child), + node.isPruneWhenEmpty(), + node.getPassThroughSpecifications(), + node.getRequiredVariables(), + updatedMarkerSymbols, + node.getSpecification(), + node.getPrePartitioned(), + node.getPreSorted(), + node.getHashSymbol(), + node.getHandle()))) + .orElse(Result.empty()); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantTableFunctionProcessor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantTableFunctionProcessor.java new file mode 100644 index 0000000000000..48f58bce5952a --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantTableFunctionProcessor.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.google.common.collect.ImmutableList; + +import java.util.NoSuchElementException; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMost; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; + +/** + * Table function can take multiple table arguments. Each argument is either "prune when empty" or "keep when empty". + * "Prune when empty" means that if this argument has no rows, the function result is empty, so the function can be + * removed from the plan, and replaced with empty values. + * "Keep when empty" means that even if the argument has no rows, the function should still be executed, and it can + * return a non-empty result. + * All the table arguments are combined into a single source of a TableFunctionProcessorNode. If either argument is + * "prune when empty", the overall result is "prune when empty". This rule removes a redundant TableFunctionProcessorNode + * based on the "prune when empty" property. + */ +public class RemoveRedundantTableFunctionProcessor + implements Rule +{ + private static final Pattern PATTERN = tableFunctionProcessor(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) + { + if (node.isPruneWhenEmpty() && node.getSource().isPresent()) { + if (isAtMost(node.getSource().orElseThrow(NoSuchElementException::new), context.getLookup(), 0)) { + return Result.ofPlanNode( + new ValuesNode(node.getSourceLocation(), + node.getId(), + node.getOutputVariables(), + ImmutableList.of(), + Optional.empty())); + } + } + + return Result.empty(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteExcludeColumnsFunctionToProjection.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteExcludeColumnsFunctionToProjection.java new file mode 100644 index 0000000000000..1857dafa9493c --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteExcludeColumnsFunctionToProjection.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.operator.table.ExcludeColumns.ExcludeColumnsFunctionHandle; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; + +import java.util.List; +import java.util.NoSuchElementException; + +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.Iterators.getOnlyElement; +/** + * Rewrite a TableFunctionProcessorNode into a Project node if the table function is exclude_columns. + *
+ * - TableFunctionProcessorNode
+ *   propperOutputs=[A, B]
+ *   passthroughColumns=[C, D]
+ *   - (input) plan which produces symbols [A, B, C, D]
+ * 
+ * into + *
+ * - Project
+ *   assignments={A, B, C, D}
+ *   - (input) plan which produces symbols [A, B, C, D]
+ * 
+ */ +public class RewriteExcludeColumnsFunctionToProjection + implements Rule +{ + private static final Pattern PATTERN = tableFunctionProcessor(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) + { + if (!(node.getHandle().getFunctionHandle() instanceof ExcludeColumnsFunctionHandle)) { + return Result.empty(); + } + + List inputSymbols = getOnlyElement(node.getRequiredVariables().iterator()); + List outputSymbols = node.getOutputVariables(); + + checkState(inputSymbols.size() == outputSymbols.size(), "inputSymbols size differs from outputSymbols size"); + Assignments.Builder assignments = Assignments.builder(); + for (int i = 0; i < outputSymbols.size(); i++) { + assignments.put(outputSymbols.get(i), inputSymbols.get(i)); + } + + return Result.ofPlanNode(new ProjectNode( + node.getId(), + node.getSource().orElseThrow(NoSuchElementException::new), + assignments.build())); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionProcessorToTableScan.java similarity index 61% rename from presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java rename to presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionProcessorToTableScan.java index 2418377c7ac53..6cbf4378b334b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionProcessorToTableScan.java @@ -23,7 +23,7 @@ import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.Rule; -import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.google.common.collect.ImmutableMap; import java.util.List; @@ -31,65 +31,70 @@ import static com.facebook.presto.matching.Pattern.empty; import static com.facebook.presto.sql.planner.plan.Patterns.sources; -import static com.facebook.presto.sql.planner.plan.Patterns.tableFunction; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; /* - * This process converts connector-resolvable TableFunctionNodes into equivalent - * TableScanNodes by invoking the connector’s applyTableFunction() during planning. - * It allows table-valued functions whose results can be expressed as a ConnectorTableHandle - * to be treated like regular scans and benefit from normal scan optimizations. + * This rule converts connector-resolvable TableFunctionProcessorNodes into equivalent + * TableScanNodes by invoking the connector's applyTableFunction() method during query planning. + * + * It enables table-valued functions whose results can be represented as a ConnectorTableHandle + * to be treated like regular table scans, allowing them to benefit from standard scan optimizations. * * Example: * Before Transformation: * TableFunction(my_function(arg1, arg2)) * * After Transformation: - * TableScan(my_function(arg1, arg2)).applyTableFunction_tableHandle) - * assignments: {outputVar1 -> my_function(arg1, arg2)).applyTableFunction_colHandle1, - * outputVar2 -> my_function(arg1, arg2)).applyTableFunction_colHandle2} + * TableScan(my_function(arg1, arg2)) + * assignments: { + * outputVar1 -> my_function(arg1, arg2)_colHandle1, + * outputVar2 -> my_function(arg1, arg2)_colHandle2 + * } */ -public class RewriteTableFunctionToTableScan - implements Rule +public class TransformTableFunctionProcessorToTableScan + implements Rule { - private static final Pattern PATTERN = tableFunction() + private static final Pattern PATTERN = tableFunctionProcessor() .with(empty(sources())); private final Metadata metadata; - public RewriteTableFunctionToTableScan(Metadata metadata) + public TransformTableFunctionProcessorToTableScan(Metadata metadata) { this.metadata = requireNonNull(metadata, "metadata is null"); } @Override - public Pattern getPattern() + public Pattern getPattern() { return PATTERN; } @Override - public Result apply(TableFunctionNode tableFunctionNode, Captures captures, Context context) + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) { - Optional> result = metadata.applyTableFunction(context.getSession(), tableFunctionNode.getHandle()); + Optional> result = metadata.applyTableFunction(context.getSession(), node.getHandle()); if (!result.isPresent()) { return Result.empty(); } List columnHandles = result.get().getColumnHandles(); - checkState(tableFunctionNode.getOutputVariables().size() == columnHandles.size(), "returned table does not match the node's output"); + checkState(node.getOutputVariables().size() == columnHandles.size(), + "Connector returned %s columns but TableFunctionProcessorNode expects %s outputs", + columnHandles.size(), node.getOutputVariables().size()); ImmutableMap.Builder assignments = ImmutableMap.builder(); for (int i = 0; i < columnHandles.size(); i++) { - assignments.put(tableFunctionNode.getOutputVariables().get(i), columnHandles.get(i)); + assignments.put(node.getOutputVariables().get(i), columnHandles.get(i)); } return Result.ofPlanNode(new TableScanNode( - tableFunctionNode.getSourceLocation(), - tableFunctionNode.getId(), + node.getSourceLocation(), + node.getId(), result.get().getTableHandle(), - tableFunctionNode.getOutputVariables(), + node.getOutputVariables(), assignments.buildOrThrow(), TupleDomain.all(), TupleDomain.all(), Optional.empty())); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionToTableFunctionProcessor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionToTableFunctionProcessor.java new file mode 100644 index 0000000000000..8d143ea0d006e --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionToTableFunctionProcessor.java @@ -0,0 +1,1032 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.JoinNode; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.WindowNode; +import com.facebook.presto.spi.plan.WindowNode.Frame; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.tree.QualifiedName; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; + +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.plan.JoinType.FULL; +import static com.facebook.presto.spi.plan.JoinType.INNER; +import static com.facebook.presto.spi.plan.JoinType.LEFT; +import static com.facebook.presto.spi.plan.JoinType.RIGHT; +import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.UNBOUNDED_FOLLOWING; +import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING; +import static com.facebook.presto.spi.plan.WindowNode.Frame.WindowType.ROWS; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunction; +import static com.facebook.presto.sql.relational.Expressions.coalesce; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.EQUAL; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +/** + * This rule prepares cartesian product of partitions + * from all inputs of table function. + *

+ * It rewrites TableFunctionNode with potentially many sources + * into a TableFunctionProcessorNode. The new node has one + * source being a combination of the original sources. + *

+ * The original sources are combined with joins. The join + * conditions depend on the prune when empty property, and on + * the co-partitioning of sources. + *

+ * The resulting source should be partitioned and ordered + * according to combined schemas from the component sources. + *

+ * Example transformation for two sources, both with set semantics + * and KEEP WHEN EMPTY property: + *

+ * - TableFunction foo
+ *      - source T1(a1, b1) PARTITION BY a1 ORDER BY b1
+ *      - source T2(a2, b2) PARTITION BY a2
+ * 
+ * Is transformed into: + *
+ * - TableFunctionDataProcessor foo
+ *      PARTITION BY (a1, a2), ORDER BY combined_row_number
+ *      - Project
+ *          marker_1 <= IF(table1_row_number = combined_row_number, table1_row_number, CAST(null AS bigint))
+ *          marker_2 <= IF(table2_row_number = combined_row_number, table2_row_number, CAST(null AS bigint))
+ *          - Project
+ *              combined_row_number <= IF(COALESCE(table1_row_number, BIGINT '-1') > COALESCE(table2_row_number, BIGINT '-1'), table1_row_number, table2_row_number)
+ *              combined_partition_size <= IF(COALESCE(table1_partition_size, BIGINT '-1') > COALESCE(table2_partition_size, BIGINT '-1'), table1_partition_size, table2_partition_size)
+ *              - FULL Join
+ *                  [table1_row_number = table2_row_number OR
+ *                   table1_row_number > table2_partition_size AND table2_row_number = BIGINT '1' OR
+ *                   table2_row_number > table1_partition_size AND table1_row_number = BIGINT '1']
+ *                  - Window [PARTITION BY a1 ORDER BY b1]
+ *                      table1_row_number <= row_number()
+ *                      table1_partition_size <= count()
+ *                          - source T1(a1, b1)
+ *                  - Window [PARTITION BY a2]
+ *                      table2_row_number <= row_number()
+ *                      table2_partition_size <= count()
+ *                          - source T2(a2, b2)
+ * 
+ */ +public class TransformTableFunctionToTableFunctionProcessor + implements Rule +{ + private static final Pattern PATTERN = tableFunction(); + private static final Frame FULL_FRAME = new Frame( + ROWS, + UNBOUNDED_PRECEDING, + Optional.empty(), + Optional.empty(), + UNBOUNDED_FOLLOWING, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + private static final DataOrganizationSpecification UNORDERED_SINGLE_PARTITION = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); + + private final Metadata metadata; + + public TransformTableFunctionToTableFunctionProcessor(Metadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionNode node, Captures captures, Context context) + { + if (node.getSources().isEmpty()) { + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.empty(), + false, + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + + if (node.getSources().size() == 1) { + // Single source does not require pre-processing. + // If the source has row semantics, its specification is empty. + // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. + // This property can be used later to choose optimal distribution. + TableArgumentProperties sourceProperties = getOnlyElement(node.getTableArgumentProperties()); + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(getOnlyElement(node.getSources())), + sourceProperties.isPruneWhenEmpty(), + ImmutableList.of(sourceProperties.getPassThroughSpecification()), + ImmutableList.of(sourceProperties.getRequiredColumns()), + Optional.empty(), + sourceProperties.getSpecification(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + Map sources = mapSourcesByName(node.getSources(), node.getTableArgumentProperties()); + ImmutableList.Builder intermediateResultsBuilder = ImmutableList.builder(); + + FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager(); + + // Create call expression for row_number + FunctionHandle rowNumberFunctionHandle = functionAndTypeManager.resolveFunction(Optional.of(context.getSession().getSessionFunctions()), + context.getSession().getTransactionId(), + functionAndTypeManager.getFunctionAndTypeResolver().qualifyObjectName(QualifiedName.of("row_number")), + ImmutableList.of()); + + FunctionMetadata rowNumberFunctionMetadata = functionAndTypeManager.getFunctionMetadata(rowNumberFunctionHandle); + CallExpression rowNumberFunction = new CallExpression("row_number", rowNumberFunctionHandle, functionAndTypeManager.getType(rowNumberFunctionMetadata.getReturnType()), ImmutableList.of()); + + // Create call expression for count + FunctionHandle countFunctionHandle = functionAndTypeManager.resolveFunction(Optional.of(context.getSession().getSessionFunctions()), + context.getSession().getTransactionId(), + functionAndTypeManager.getFunctionAndTypeResolver().qualifyObjectName(QualifiedName.of("count")), + ImmutableList.of()); + + FunctionMetadata countFunctionMetadata = functionAndTypeManager.getFunctionMetadata(countFunctionHandle); + CallExpression countFunction = new CallExpression("count", countFunctionHandle, functionAndTypeManager.getType(countFunctionMetadata.getReturnType()), ImmutableList.of()); + + // handle co-partitioned sources + for (List copartitioningList : node.getCopartitioningLists()) { + List sourceList = copartitioningList.stream() + .map(sources::get) + .collect(toImmutableList()); + intermediateResultsBuilder.add(copartition(sourceList, rowNumberFunction, countFunction, context, metadata)); + } + + // prepare non-co-partitioned sources + Set copartitionedSources = node.getCopartitioningLists().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + sources.entrySet().stream() + .filter(entry -> !copartitionedSources.contains(entry.getKey())) + .map(entry -> planWindowFunctionsForSource(entry.getValue().source(), entry.getValue().properties(), rowNumberFunction, countFunction, context)) + .forEach(intermediateResultsBuilder::add); + + NodeWithVariables finalResultSource; + + List intermediateResultSources = intermediateResultsBuilder.build(); + if (intermediateResultSources.size() == 1) { + finalResultSource = getOnlyElement(intermediateResultSources); + } + else { + NodeWithVariables first = intermediateResultSources.get(0); + NodeWithVariables second = intermediateResultSources.get(1); + JoinedNodes joined = join(first, second, context, metadata); + + for (int i = 2; i < intermediateResultSources.size(); i++) { + NodeWithVariables joinedWithSymbols = appendHelperSymbolsForJoinedNodes(joined, context, metadata); + joined = join(joinedWithSymbols, intermediateResultSources.get(i), context, metadata); + } + + finalResultSource = appendHelperSymbolsForJoinedNodes(joined, context, metadata); + } + + // For each source, all source's output symbols are mapped to the source's row number symbol. + // The row number symbol will be later converted to a marker of "real" input rows vs "filler" input rows of the source. + // The "filler" input rows are the rows appended while joining partitions of different lengths, + // to fill the smaller partition up to the bigger partition's size. They are a side effect of the algorithm, + // and should not be processed by the table function. + Map rowNumberSymbols = finalResultSource.rowNumberSymbolsMapping(); + + // The max row number symbol from all joined partitions. + VariableReferenceExpression finalRowNumberSymbol = finalResultSource.rowNumber(); + // Combined partitioning lists from all sources. + List finalPartitionBy = finalResultSource.partitionBy(); + + NodeWithMarkers marked = appendMarkerSymbols(finalResultSource.node(), ImmutableSet.copyOf(rowNumberSymbols.values()), finalRowNumberSymbol, context, metadata); + + // Remap the symbol mapping: replace the row number symbol with the corresponding marker symbol. + // In the new map, every source symbol is associated with the corresponding marker symbol. + // Null value of the marker indicates that the source value should be ignored by the table function. + ImmutableMap markerSymbols = rowNumberSymbols.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> marked.variableToMarker().get(entry.getValue()))); + + // Use the final row number symbol for ordering the combined sources. + // It runs along each partition in the cartesian product, numbering the partition's rows according to the expected ordering / orderings. + // note: ordering is necessary even if all the source tables are not ordered. Thanks to the ordering, the original rows + // of each input table come before the "filler" rows. + ImmutableList.Builder newOrderings = ImmutableList.builder(); + newOrderings.add(new Ordering(finalRowNumberSymbol, ASC_NULLS_LAST)); + Optional finalOrderBy = Optional.of(new OrderingScheme(newOrderings.build())); + + // derive the prune when empty property + boolean pruneWhenEmpty = node.getTableArgumentProperties().stream().anyMatch(TableArgumentProperties::isPruneWhenEmpty); + + // Combine the pass through specifications from all sources + List passThroughSpecifications = node.getTableArgumentProperties().stream() + .map(TableArgumentProperties::getPassThroughSpecification) + .collect(toImmutableList()); + + // Combine the required symbols from all sources + List> requiredVariables = node.getTableArgumentProperties().stream() + .map(TableArgumentProperties::getRequiredColumns) + .collect(toImmutableList()); + + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(marked.node()), + pruneWhenEmpty, + passThroughSpecifications, + requiredVariables, + Optional.of(markerSymbols), + Optional.of(new DataOrganizationSpecification(finalPartitionBy, finalOrderBy)), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + + private static Map mapSourcesByName(List sources, List properties) + { + return Streams.zip(sources.stream(), properties.stream(), SourceWithProperties::new) + .collect(toImmutableMap(entry -> entry.properties().getArgumentName(), identity())); + } + + private static NodeWithVariables planWindowFunctionsForSource( + PlanNode source, + TableArgumentProperties argumentProperties, + CallExpression rowNumberFunction, + CallExpression countFunction, + Context context) + { + String argumentName = argumentProperties.getArgumentName(); + + VariableReferenceExpression rowNumber = context.getVariableAllocator().newVariable(argumentName + "_row_number", BIGINT); + Map rowNumberSymbolMapping = source.getOutputVariables().stream() + .collect(toImmutableMap(identity(), symbol -> rowNumber)); + + VariableReferenceExpression partitionSize = context.getVariableAllocator().newVariable(argumentName + "_partition_size", BIGINT); + + // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. + // If the source has row semantics, its specification is empty. Currently, such source is processed + // as if it was a single partition. Alternatively, it could be split into smaller partitions of arbitrary size. + DataOrganizationSpecification specification = argumentProperties.getSpecification().orElse(UNORDERED_SINGLE_PARTITION); + + PlanNode innerWindow = new WindowNode( + source.getSourceLocation(), + context.getIdAllocator().getNextId(), + source, + specification, + ImmutableMap.of( + rowNumber, new WindowNode.Function(rowNumberFunction, FULL_FRAME, false)), + Optional.empty(), + ImmutableSet.of(), + 0); + PlanNode window = new WindowNode( + innerWindow.getSourceLocation(), + context.getIdAllocator().getNextId(), + innerWindow, + specification, + ImmutableMap.of( + partitionSize, new WindowNode.Function(countFunction, FULL_FRAME, false)), + Optional.empty(), + ImmutableSet.of(), + 0); + + return new NodeWithVariables(window, rowNumber, partitionSize, specification.getPartitionBy(), argumentProperties.isPruneWhenEmpty(), rowNumberSymbolMapping); + } + + private static NodeWithVariables copartition( + List sourceList, + CallExpression rowNumberFunction, + CallExpression countFunction, + Context context, + Metadata metadata) + { + checkArgument(sourceList.size() >= 2, "co-partitioning list should contain at least two tables"); + + // Reorder the co-partitioned sources to process the sources with prune when empty property first. + // It allows to use inner or side joins instead of outer joins. + sourceList = sourceList.stream() + .sorted(Comparator.comparingInt(source -> source.properties().isPruneWhenEmpty() ? -1 : 1)) + .collect(toImmutableList()); + + NodeWithVariables first = planWindowFunctionsForSource(sourceList.get(0).source(), sourceList.get(0).properties(), rowNumberFunction, countFunction, context); + NodeWithVariables second = planWindowFunctionsForSource(sourceList.get(1).source(), sourceList.get(1).properties(), rowNumberFunction, countFunction, context); + JoinedNodes copartitioned = copartition(first, second, context, metadata); + + for (int i = 2; i < sourceList.size(); i++) { + NodeWithVariables copartitionedWithSymbols = appendHelperSymbolsForCopartitionedNodes(copartitioned, context, metadata); + NodeWithVariables next = planWindowFunctionsForSource(sourceList.get(i).source(), sourceList.get(i).properties(), rowNumberFunction, countFunction, context); + copartitioned = copartition(copartitionedWithSymbols, next, context, metadata); + } + + return appendHelperSymbolsForCopartitionedNodes(copartitioned, context, metadata); + } + + private static JoinedNodes copartition(NodeWithVariables left, NodeWithVariables right, Context context, Metadata metadata) + { + checkArgument(left.partitionBy().size() == right.partitionBy().size(), "co-partitioning lists do not match"); + + // In StatementAnalyzer we require that co-partitioned tables have non-empty partitioning column lists. + // Co-partitioning tables with empty partition by would be ineffective. + checkState(!left.partitionBy().isEmpty(), "co-partitioned tables must have partitioning columns"); + + FunctionResolution functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + + Optional copartitionConjuncts = Streams.zip( + left.partitionBy.stream(), + right.partitionBy.stream(), + (leftColumn, rightColumn) -> new CallExpression("NOT", + functionResolution.notFunction(), + BOOLEAN, + ImmutableList.of( + new CallExpression(IS_DISTINCT_FROM.name(), + functionResolution.comparisonFunction(IS_DISTINCT_FROM, leftColumn.getType(), rightColumn.getType()), + BOOLEAN, + ImmutableList.of(leftColumn, rightColumn))))) + .map(expr -> expr) + .reduce((expr, conjunct) -> new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of(expr, conjunct))); + + // Align matching partitions (co-partitions) from left and right source, according to row number. + // Matching partitions are identified by their corresponding partitioning columns being NOT DISTINCT from each other. + // If one or both sources are ordered, the row number reflects the ordering. + // The second and third disjunct in the join condition account for the situation when partitions have different sizes. + // It preserves the outstanding rows from the bigger partition, matching them to the first row from the smaller partition. + // + // (P1_1 IS NOT DISTINCT FROM P2_1) AND (P1_2 IS NOT DISTINCT FROM P2_2) AND ... + // AND ( + // R1 = R2 + // OR + // (R1 > S2 AND R2 = 1) + // OR + // (R2 > S1 AND R1 = 1)) + + SpecialFormExpression orExpression = new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.rowNumber)), + new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, new ConstantExpression(1L, BIGINT))))), + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, left.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, new ConstantExpression(1L, BIGINT))))))))); + RowExpression joinCondition = copartitionConjuncts.map( + conjunct -> new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of(conjunct, orExpression))) + .orElse(orExpression); + + // The join type depends on the prune when empty property of the sources. + // If a source is prune when empty, we should not process any co-partition which is not present in this source, + // so effectively the other source becomes inner side of the join. + // + // example: + // table T1 partition by P1 table T2 partition by P2 + // P1 C1 P2 C2 + // ---------- ---------- + // 1 'a' 2 'c' + // 2 'b' 3 'd' + // + // co-partitioning results: + // 1) T1 is prune when empty: do LEFT JOIN to drop co-partition '3' + // P1 C1 P2 C2 + // ------------------------ + // 1 'a' null null + // 2 'b' 2 'c' + // + // 2) T2 is prune when empty: do RIGHT JOIN to drop co-partition '1' + // P1 C1 P2 C2 + // ------------------------ + // 2 'b' 2 'c' + // null null 3 'd' + // + // 3) T1 and T2 are both prune when empty: do INNER JOIN to drop co-partitions '1' and '3' + // P1 C1 P2 C2 + // ------------------------ + // 2 'b' 2 'c' + // + // 4) neither table is prune when empty: do FULL JOIN to preserve all co-partitions + // P1 C1 P2 C2 + // ------------------------ + // 1 'a' null null + // 2 'b' 2 'c' + // null null 3 'd' + JoinType joinType; + if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { + joinType = INNER; + } + else if (left.pruneWhenEmpty()) { + joinType = LEFT; + } + else if (right.pruneWhenEmpty()) { + joinType = RIGHT; + } + else { + joinType = FULL; + } + + return new JoinedNodes( + new JoinNode( + Optional.empty(), + context.getIdAllocator().getNextId(), + joinType, + left.node(), + right.node(), + ImmutableList.of(), + Stream.concat(left.node().getOutputVariables().stream(), + right.node().getOutputVariables().stream()) + .collect(Collectors.toList()), + Optional.of(joinCondition), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of()), + left.rowNumber(), + left.partitionSize(), + left.partitionBy(), + left.pruneWhenEmpty(), + left.rowNumberSymbolsMapping(), + right.rowNumber(), + right.partitionSize(), + right.partitionBy(), + right.pruneWhenEmpty(), + right.rowNumberSymbolsMapping()); + } + + private static NodeWithVariables appendHelperSymbolsForCopartitionedNodes( + JoinedNodes copartitionedNodes, + Context context, + Metadata metadata) + { + checkArgument(copartitionedNodes.leftPartitionBy().size() == copartitionedNodes.rightPartitionBy().size(), "co-partitioning lists do not match"); + + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. + VariableReferenceExpression joinedRowNumber = context.getVariableAllocator().newVariable("combined_row_number", BIGINT); + RowExpression rowNumberExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.leftRowNumber(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.rightRowNumber(), + new ConstantExpression(-1L, BIGINT)))), + copartitionedNodes.leftRowNumber(), + copartitionedNodes.rightRowNumber())); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. + VariableReferenceExpression joinedPartitionSize = context.getVariableAllocator().newVariable("combined_partition_size", BIGINT); + RowExpression partitionSizeExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.leftPartitionSize(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.rightPartitionSize(), + new ConstantExpression(-1L, BIGINT)))), + copartitionedNodes.leftPartitionSize(), + copartitionedNodes.rightPartitionSize())); + + // Derive partitioning columns for joined partitions. + // Either the combined partitioning columns are pairwise NOT DISTINCT (this is the co-partitioning rule), + // or one of them is null as a result of outer join. + ImmutableList.Builder joinedPartitionBy = ImmutableList.builder(); + Assignments.Builder joinedPartitionByAssignments = Assignments.builder(); + for (int i = 0; i < copartitionedNodes.leftPartitionBy().size(); i++) { + VariableReferenceExpression leftColumn = copartitionedNodes.leftPartitionBy().get(i); + VariableReferenceExpression rightColumn = copartitionedNodes.rightPartitionBy().get(i); + Type type = context.getVariableAllocator().getVariables().get(leftColumn.getName()); + + VariableReferenceExpression joinedColumn = context.getVariableAllocator().newVariable("combined_partition_column", type); + joinedPartitionByAssignments.put(joinedColumn, coalesce(leftColumn, rightColumn)); + joinedPartitionBy.add(joinedColumn); + } + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + copartitionedNodes.joinedNode(), + Assignments.builder() + .putAll( + copartitionedNodes.joinedNode().getOutputVariables().stream() + .collect(toImmutableMap(v -> v, v -> v))) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .putAll(joinedPartitionByAssignments.build()) + .build()); + boolean joinedPruneWhenEmpty = copartitionedNodes.leftPruneWhenEmpty() || copartitionedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(copartitionedNodes.leftRowNumberSymbolsMapping()) + .putAll(copartitionedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithVariables(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy.build(), joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static JoinedNodes join(NodeWithVariables left, NodeWithVariables right, Context context, Metadata metadata) + { + // Align rows from left and right source according to row number. Because every partition is row-numbered, this produces cartesian product of partitions. + // If one or both sources are ordered, the row number reflects the ordering. + // The second and third disjunct in the join condition account for the situation when partitions have different sizes. It preserves the outstanding rows + // from the bigger partition, matching them to the first row from the smaller partition. + // + // R1 = R2 + // OR + // (R1 > S2 AND R2 = 1) + // OR + // (R2 > S1 AND R1 = 1) + + FunctionResolution functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + RowExpression joinCondition = new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.rowNumber)), + new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, new ConstantExpression(1L, BIGINT))))), + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, left.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, new ConstantExpression(1L, BIGINT))))))))); + JoinType joinType; + if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { + joinType = INNER; + } + else if (left.pruneWhenEmpty()) { + joinType = LEFT; + } + else if (right.pruneWhenEmpty()) { + joinType = RIGHT; + } + else { + joinType = FULL; + } + + return new JoinedNodes( + new JoinNode( + Optional.empty(), + context.getIdAllocator().getNextId(), + joinType, + left.node(), + right.node(), + ImmutableList.of(), + Stream.concat(left.node().getOutputVariables().stream(), + right.node().getOutputVariables().stream()) + .collect(Collectors.toList()), + Optional.of(joinCondition), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of()), + left.rowNumber(), + left.partitionSize(), + left.partitionBy(), + left.pruneWhenEmpty(), + left.rowNumberSymbolsMapping(), + right.rowNumber(), + right.partitionSize(), + right.partitionBy(), + right.pruneWhenEmpty(), + right.rowNumberSymbolsMapping()); + } + + private static NodeWithVariables appendHelperSymbolsForJoinedNodes(JoinedNodes joinedNodes, Context context, Metadata metadata) + { + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. + VariableReferenceExpression joinedRowNumber = context.getVariableAllocator().newVariable("combined_row_number", BIGINT); + RowExpression rowNumberExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.leftRowNumber(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.rightRowNumber(), + new ConstantExpression(-1L, BIGINT)))), + joinedNodes.leftRowNumber(), + joinedNodes.rightRowNumber())); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. + VariableReferenceExpression joinedPartitionSize = context.getVariableAllocator().newVariable("combined_partition_size", BIGINT); + RowExpression partitionSizeExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.leftPartitionSize(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.rightPartitionSize(), + new ConstantExpression(-1L, BIGINT)))), + joinedNodes.leftPartitionSize(), + joinedNodes.rightPartitionSize())); + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + joinedNodes.joinedNode(), + Assignments.builder() + .putAll( + joinedNodes.joinedNode().getOutputVariables().stream() + .collect(toImmutableMap(v -> v, v -> v))) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .build()); + + List joinedPartitionBy = ImmutableList.builder() + .addAll(joinedNodes.leftPartitionBy()) + .addAll(joinedNodes.rightPartitionBy()) + .build(); + + boolean joinedPruneWhenEmpty = joinedNodes.leftPruneWhenEmpty() || joinedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(joinedNodes.leftRowNumberSymbolsMapping()) + .putAll(joinedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithVariables(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy, joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static NodeWithMarkers appendMarkerSymbols(PlanNode node, Set variables, VariableReferenceExpression referenceSymbol, Context context, Metadata metadata) + { + Assignments.Builder assignments = Assignments.builder(); + assignments.putAll( + node.getOutputVariables().stream() + .collect(toImmutableMap(v -> v, v -> v))); + + ImmutableMap.Builder variablesToMarkers = ImmutableMap.builder(); + + for (VariableReferenceExpression variable : variables) { + VariableReferenceExpression marker = context.getVariableAllocator().newVariable("marker", BIGINT); + variablesToMarkers.put(variable, marker); + RowExpression ifExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + EQUAL.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.EQUAL, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of(variable, referenceSymbol)), + variable, + new ConstantExpression(null, BIGINT))); + assignments.put(marker, ifExpression); + } + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + node, + assignments.build()); + + return new NodeWithMarkers(project, variablesToMarkers.buildOrThrow()); + } + + private static class SourceWithProperties + { + private final PlanNode source; + private final TableArgumentProperties properties; + + public SourceWithProperties(PlanNode source, TableArgumentProperties properties) + { + this.source = requireNonNull(source, "source is null"); + this.properties = requireNonNull(properties, "properties is null"); + } + + public PlanNode source() + { + return source; + } + + public TableArgumentProperties properties() + { + return properties; + } + } + + public static final class NodeWithVariables + { + private final PlanNode node; + private final VariableReferenceExpression rowNumber; + private final VariableReferenceExpression partitionSize; + private final List partitionBy; + private final boolean pruneWhenEmpty; + private final Map rowNumberSymbolsMapping; + + public NodeWithVariables(PlanNode node, VariableReferenceExpression rowNumber, VariableReferenceExpression partitionSize, + List partitionBy, boolean pruneWhenEmpty, + Map rowNumberSymbolsMapping) + { + this.node = requireNonNull(node, "node is null"); + this.rowNumber = requireNonNull(rowNumber, "rowNumber is null"); + this.partitionSize = requireNonNull(partitionSize, "partitionSize is null"); + this.partitionBy = ImmutableList.copyOf(partitionBy); + this.pruneWhenEmpty = pruneWhenEmpty; + this.rowNumberSymbolsMapping = ImmutableMap.copyOf(rowNumberSymbolsMapping); + } + + public PlanNode node() + { + return node; + } + + public VariableReferenceExpression rowNumber() + { + return rowNumber; + } + + public VariableReferenceExpression partitionSize() + { + return partitionSize; + } + + public List partitionBy() + { + return partitionBy; + } + + public boolean pruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public Map rowNumberSymbolsMapping() + { + return rowNumberSymbolsMapping; + } + } + + private static class JoinedNodes + { + private final PlanNode joinedNode; + private final VariableReferenceExpression leftRowNumber; + private final VariableReferenceExpression leftPartitionSize; + private final List leftPartitionBy; + private final boolean leftPruneWhenEmpty; + private final Map leftRowNumberSymbolsMapping; + private final VariableReferenceExpression rightRowNumber; + private final VariableReferenceExpression rightPartitionSize; + private final List rightPartitionBy; + private final boolean rightPruneWhenEmpty; + private final Map rightRowNumberSymbolsMapping; + + public JoinedNodes( + PlanNode joinedNode, + VariableReferenceExpression leftRowNumber, + VariableReferenceExpression leftPartitionSize, + List leftPartitionBy, + boolean leftPruneWhenEmpty, + Map leftRowNumberSymbolsMapping, + VariableReferenceExpression rightRowNumber, + VariableReferenceExpression rightPartitionSize, + List rightPartitionBy, + boolean rightPruneWhenEmpty, + Map rightRowNumberSymbolsMapping) + { + this.joinedNode = requireNonNull(joinedNode, "joinedNode is null"); + this.leftRowNumber = requireNonNull(leftRowNumber, "leftRowNumber is null"); + this.leftPartitionSize = requireNonNull(leftPartitionSize, "leftPartitionSize is null"); + this.leftPartitionBy = ImmutableList.copyOf(requireNonNull(leftPartitionBy, "leftPartitionBy is null")); + this.leftPruneWhenEmpty = leftPruneWhenEmpty; + this.leftRowNumberSymbolsMapping = ImmutableMap.copyOf(requireNonNull(leftRowNumberSymbolsMapping, "leftRowNumberSymbolsMapping is null")); + this.rightRowNumber = requireNonNull(rightRowNumber, "rightRowNumber is null"); + this.rightPartitionSize = requireNonNull(rightPartitionSize, "rightPartitionSize is null"); + this.rightPartitionBy = ImmutableList.copyOf(requireNonNull(rightPartitionBy, "rightPartitionBy is null")); + this.rightPruneWhenEmpty = rightPruneWhenEmpty; + this.rightRowNumberSymbolsMapping = ImmutableMap.copyOf(requireNonNull(rightRowNumberSymbolsMapping, "rightRowNumberSymbolsMapping is null")); + } + + public PlanNode joinedNode() + { + return joinedNode; + } + public VariableReferenceExpression leftRowNumber() + { + return leftRowNumber; + } + public VariableReferenceExpression leftPartitionSize() + { + return leftPartitionSize; + } + public List leftPartitionBy() + { + return leftPartitionBy; + } + public boolean leftPruneWhenEmpty() + { + return leftPruneWhenEmpty; + } + public Map leftRowNumberSymbolsMapping() + { + return leftRowNumberSymbolsMapping; + } + public VariableReferenceExpression rightRowNumber() + { + return rightRowNumber; + } + public VariableReferenceExpression rightPartitionSize() + { + return rightPartitionSize; + } + public List rightPartitionBy() + { + return rightPartitionBy; + } + public boolean rightPruneWhenEmpty() + { + return rightPruneWhenEmpty; + } + public Map rightRowNumberSymbolsMapping() + { + return rightRowNumberSymbolsMapping; + } + } + + private static class NodeWithMarkers + { + private final PlanNode node; + private final Map variableToMarker; + + public NodeWithMarkers(PlanNode node, Map variableToMarker) + { + this.node = requireNonNull(node, "node is null"); + this.variableToMarker = ImmutableMap.copyOf(requireNonNull(variableToMarker, "symbolToMarker is null")); + } + + public PlanNode node() + { + return node; + } + + public Map variableToMarker() + { + return variableToMarker; + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index 0b7021f66baf4..047326398832a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -80,6 +80,7 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.google.common.annotations.VisibleForTesting; import com.google.common.cache.CacheBuilder; @@ -100,6 +101,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -417,7 +419,59 @@ public PlanWithProperties visitWindow(WindowNode node, PreferredProperties prefe @Override public PlanWithProperties visitTableFunction(TableFunctionNode node, PreferredProperties preferredProperties) { - throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanWithProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, PreferredProperties preferredProperties) + { + if (!node.getSource().isPresent()) { + return new PlanWithProperties(node, deriveProperties(node, ImmutableList.of())); + } + + if (!node.getSpecification().isPresent()) { + // node.getSpecification.isEmpty() indicates that there were no sources or a single source with row semantics. + // The case of no sources was addressed above. + // The case of a single source with row semantics is addressed here. A single source with row semantics can be distributed arbitrarily. + PlanWithProperties child = planChild(node, PreferredProperties.any()); + return rebaseAndDeriveProperties(node, child); + } + + List partitionBy = node.getSpecification().orElseThrow(NoSuchElementException::new).getPartitionBy(); + List> desiredProperties = new ArrayList<>(); + if (!partitionBy.isEmpty()) { + desiredProperties.add(new GroupingProperty<>(partitionBy)); + } + node.getSpecification().orElseThrow(NoSuchElementException::new) + .getOrderingScheme() + .ifPresent(orderingScheme -> + orderingScheme.getOrderByVariables().stream() + .map(variable -> new SortingProperty<>(variable, orderingScheme.getOrdering(variable))) + .forEach(desiredProperties::add)); + + PlanWithProperties child = planChild(node, PreferredProperties.partitionedWithLocal(ImmutableSet.copyOf(partitionBy), desiredProperties)); + + // TODO do not gather if already gathered + if (!node.isPruneWhenEmpty()) { + child = withDerivedProperties( + gatheringExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode()), + child.getProperties()); + } + else if (!isStreamPartitionedOn(child.getProperties(), partitionBy) && + !isNodePartitionedOn(child.getProperties(), partitionBy)) { + if (partitionBy.isEmpty()) { + child = withDerivedProperties( + gatheringExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode()), + child.getProperties()); + } + else { + child = withDerivedProperties( + partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode(), Partitioning.create(FIXED_HASH_DISTRIBUTION, partitionBy), node.getHashSymbol()), + child.getProperties()); + } + } + + return rebaseAndDeriveProperties(node, child); } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java index 1081944a4b064..46d17de0d459c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -59,6 +60,8 @@ import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.google.common.collect.ImmutableList; @@ -67,6 +70,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; @@ -111,6 +115,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -500,6 +505,87 @@ public PlanWithProperties visitDelete(DeleteNode node, StreamPreferredProperties return deriveProperties(result, child.getProperties()); } + @Override + public PlanWithProperties visitTableFunction(TableFunctionNode node, StreamPreferredProperties parentPreferences) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanWithProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, StreamPreferredProperties parentPreferences) + { + if (!node.getSource().isPresent()) { + return deriveProperties(node, ImmutableList.of()); + } + + if (!node.getSpecification().isPresent()) { + // node.getSpecification.isEmpty() indicates that there were no sources or a single source with row semantics. + // The case of no sources was addressed above. + // The case of a single source with row semantics is addressed here. Source's properties do not hold after the TableFunctionProcessorNode + PlanWithProperties child = planAndEnforce(node.getSource().orElseThrow(NoSuchElementException::new), StreamPreferredProperties.any(), StreamPreferredProperties.any()); + return rebaseAndDeriveProperties(node, ImmutableList.of(child)); + } + + List partitionBy = node.getSpecification().orElseThrow(NoSuchElementException::new).getPartitionBy(); + StreamPreferredProperties childRequirements; + if (!node.isPruneWhenEmpty()) { + childRequirements = singleStream(); + } + else { + childRequirements = parentPreferences + .constrainTo(node.getSource().orElseThrow(NoSuchElementException::new).getOutputVariables()) + .withDefaultParallelism(session) + .withPartitioning(partitionBy); + } + + PlanWithProperties child = planAndEnforce(node.getSource().orElseThrow(NoSuchElementException::new), childRequirements, childRequirements); + + List> desiredProperties = new ArrayList<>(); + if (!partitionBy.isEmpty()) { + desiredProperties.add(new GroupingProperty<>(partitionBy)); + } + node.getSpecification() + .flatMap(DataOrganizationSpecification::getOrderingScheme) + .ifPresent(orderingScheme -> + orderingScheme.getOrderByVariables().stream() + .map(variable -> new SortingProperty<>(variable, orderingScheme.getOrdering(variable))) + .forEach(desiredProperties::add)); + Iterator>> matchIterator = LocalProperties.match(child.getProperties().getLocalProperties(), desiredProperties).iterator(); + + Set prePartitionedInputs = ImmutableSet.of(); + if (!partitionBy.isEmpty()) { + Optional> groupingRequirement = matchIterator.next(); + Set unPartitionedInputs = groupingRequirement.map(LocalProperty::getColumns).orElse(ImmutableSet.of()); + prePartitionedInputs = partitionBy.stream() + .filter(symbol -> !unPartitionedInputs.contains(symbol)) + .collect(toImmutableSet()); + } + + int preSortedOrderPrefix = 0; + if (prePartitionedInputs.equals(ImmutableSet.copyOf(partitionBy))) { + while (matchIterator.hasNext() && !matchIterator.next().isPresent()) { + preSortedOrderPrefix++; + } + } + + TableFunctionProcessorNode result = new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(child.getNode()), + node.isPruneWhenEmpty(), + node.getPassThroughSpecifications(), + node.getRequiredVariables(), + node.getMarkerVariables(), + node.getSpecification(), + prePartitionedInputs, + preSortedOrderPrefix, + node.getHashSymbol(), + node.getHandle()); + + return deriveProperties(result, child.getProperties()); + } + @Override public PlanWithProperties visitMarkDistinct(MarkDistinctNode node, StreamPreferredProperties parentPreferences) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java index 9dff1d3fb86d9..7dc181ca661e0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.SortingProperty; import com.facebook.presto.spi.UniqueProperty; import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -73,6 +74,8 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -111,6 +114,7 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.lang.String.format; import static java.util.stream.Collectors.toMap; public class PropertyDerivations @@ -287,6 +291,50 @@ public ActualProperties visitWindow(WindowNode node, List inpu .build(); } + @Override + public ActualProperties visitTableFunction(TableFunctionNode node, List inputProperties) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public ActualProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, List inputProperties) + { + ImmutableList.Builder> localProperties = ImmutableList.builder(); + + if (node.getSource().isPresent()) { + ActualProperties properties = Iterables.getOnlyElement(inputProperties); + + // Only the partitioning properties of the source are passed-through, because the pass-through mechanism preserves the partitioning values. + // Sorting properties might be broken because input rows can be shuffled or nulls can be inserted as the result of pass-through. + // Constant properties might be broken because nulls can be inserted as the result of pass-through. + if (!node.getPrePartitioned().isEmpty()) { + GroupingProperty prePartitionedProperty = new GroupingProperty<>(node.getPrePartitioned()); + for (LocalProperty localProperty : properties.getLocalProperties()) { + if (!prePartitionedProperty.isSimplifiedBy(localProperty)) { + break; + } + localProperties.add(localProperty); + } + } + } + + List partitionBy = node.getSpecification() + .map(DataOrganizationSpecification::getPartitionBy) + .orElse(ImmutableList.of()); + if (!partitionBy.isEmpty()) { + localProperties.add(new GroupingProperty<>(partitionBy)); + } + + // TODO add global single stream property when there's Specification present with no partitioning columns + + return ActualProperties.builder() + .local(localProperties.build()) + .build() + // Crop properties to output columns. + .translateVariable(variable -> node.getOutputVariables().contains(variable) ? Optional.of(variable) : Optional.empty()); + } + @Override public ActualProperties visitGroupId(GroupIdNode node, List inputProperties) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index 9ec4a67577777..201ec219823af 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -68,6 +68,7 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -1084,5 +1085,25 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext> context) + { + return node.getSource().map(source -> new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(context.rewrite(source, ImmutableSet.copyOf(source.getOutputVariables()))), + node.isPruneWhenEmpty(), + node.getPassThroughSpecifications(), + node.getRequiredVariables(), + node.getMarkerVariables(), + node.getSpecification(), + node.getPrePartitioned(), + node.getPreSorted(), + node.getHashSymbol(), + node.getHandle() + )).orElse(node); + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java index ffd4806665c2c..cd4d5207fccc8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; @@ -102,6 +103,12 @@ public Range visitEnforceSingleRow(EnforceSingleRowNode node, Void context return Range.singleton(1L); } + @Override + public Range visitWindow(WindowNode node, Void context) + { + return node.getSource().accept(this, null); + } + @Override public Range visitAggregation(AggregationNode node, Void context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java index 612fb584566bb..79f1d6b1a9260 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java @@ -63,6 +63,8 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -71,11 +73,13 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; +import com.google.common.collect.Sets; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -607,6 +611,32 @@ public StreamProperties visitWindow(WindowNode node, List inpu return Iterables.getOnlyElement(inputProperties); } + @Override + public StreamProperties visitTableFunction(TableFunctionNode node, List inputProperties) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public StreamProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, List inputProperties) + { + if (!node.getSource().isPresent()) { + return StreamProperties.singleStream(); // TODO allow multiple; return partitioning properties + } + + StreamProperties properties = Iterables.getOnlyElement(inputProperties); + + Set passThroughInputs = Sets.intersection(ImmutableSet.copyOf(node.getSource().orElseThrow(NoSuchElementException::new).getOutputVariables()), ImmutableSet.copyOf(node.getOutputVariables())); + StreamProperties translatedProperties = properties.translate(column -> { + if (passThroughInputs.contains(column)) { + return Optional.of(column); + } + return Optional.empty(); + }); + + return translatedProperties.unordered(true); + } + @Override public StreamProperties visitRowNumber(RowNumberNode node, List inputProperties) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java index 6ab46ef344a13..824c392ce9619 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.ExchangeEncoding; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; @@ -41,6 +42,8 @@ import com.facebook.presto.sql.planner.plan.MergeProcessorNode; import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; @@ -55,12 +58,14 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.spi.StandardWarningCode.MULTIPLE_ORDER_BY; import static com.facebook.presto.spi.plan.AggregationNode.groupingSets; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getNodeLocation; import static com.facebook.presto.sql.planner.optimizations.PartitioningUtils.translateVariable; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -146,6 +151,27 @@ public RowExpression rewriteVariableReference(VariableReferenceExpression variab }, value); } + public OrderingSchemeWithPreSortedPrefix map(OrderingScheme orderingScheme, int preSorted) + { + ImmutableList.Builder newOrderings = ImmutableList.builder(); + int newPreSorted = preSorted; + + Set added = new HashSet<>(orderingScheme.getOrderBy().size()); + + for (int i = 0; i < orderingScheme.getOrderBy().size(); i++) { + VariableReferenceExpression variable = orderingScheme.getOrderBy().get(i).getVariable(); + VariableReferenceExpression canonical = map(variable); + if (added.add(canonical)) { + newOrderings.add(new Ordering(canonical, orderingScheme.getOrdering(variable))); + } + else if (i < preSorted) { + newPreSorted--; + } + } + + return new OrderingSchemeWithPreSortedPrefix(new OrderingScheme(newOrderings.build()), newPreSorted); + } + public OrderingScheme map(OrderingScheme orderingScheme) { // SymbolMapper inlines symbol with multiple level reference (SymbolInliner only inline single level). @@ -388,6 +414,68 @@ public TableWriterMergeNode map(TableWriterMergeNode node, PlanNode source) node.getStatisticsAggregation().map(this::map)); } + public TableFunctionProcessorNode map(TableFunctionProcessorNode node, PlanNode source) + { + // rewrite and deduplicate pass-through specifications + // note: Potentially, pass-through symbols from different sources might be recognized as semantically identical, and rewritten + // to the same symbol. Currently, we retrieve the first occurrence of a symbol, and skip all the following occurrences. + // For better performance, we could pick the occurrence with "isPartitioningColumn" property, since the pass-through mechanism + // is more efficient for partitioning columns which are guaranteed to be constant within partition. + // TODO choose a partitioning column to be retrieved while deduplicating + ImmutableList.Builder newPassThroughSpecifications = ImmutableList.builder(); + Set newPassThroughVariables = new HashSet<>(); + for (TableFunctionNode.PassThroughSpecification specification : node.getPassThroughSpecifications()) { + ImmutableList.Builder newColumns = ImmutableList.builder(); + for (TableFunctionNode.PassThroughColumn column : specification.getColumns()) { + VariableReferenceExpression newVariable = map(column.getOutputVariables()); + if (newPassThroughVariables.add(newVariable)) { + newColumns.add(new TableFunctionNode.PassThroughColumn(newVariable, column.isPartitioningColumn())); + } + } + newPassThroughSpecifications.add(new TableFunctionNode.PassThroughSpecification(specification.isDeclaredAsPassThrough(), newColumns.build())); + } + + // rewrite required symbols without deduplication. the table function expects specific input layout + List> newRequiredVariables = node.getRequiredVariables().stream() + .map(list -> list.stream() + .map(this::map) + .collect(toImmutableList())) + .collect(toImmutableList()); + + // rewrite and deduplicate marker mapping + Optional> newMarkerVariables = node.getMarkerVariables() + .map(mapping -> mapping.entrySet().stream() + .collect(toImmutableMap( + entry -> map(entry.getKey()), + entry -> map(entry.getValue()), + (first, second) -> { + checkState(first.equals(second), "Ambiguous marker symbols: %s and %s", first, second); + return first; + }))); + + // rewrite and deduplicate specification + Optional newSpecification = node.getSpecification().map(specification -> mapAndDistinct(specification, node.getPreSorted())); + + return new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs().stream() + .map(this::map) + .collect(toImmutableList()), + Optional.of(source), + node.isPruneWhenEmpty(), + newPassThroughSpecifications.build(), + newRequiredVariables, + newMarkerVariables, + newSpecification.map(SpecificationWithPreSortedPrefix::getSpecification), + node.getPrePartitioned().stream() + .map(this::map) + .collect(toImmutableSet()), + newSpecification.map(SpecificationWithPreSortedPrefix::getPreSorted).orElse(node.getPreSorted()), + node.getHashSymbol().map(this::map), + node.getHandle()); + } + private PartitioningScheme canonicalize(PartitioningScheme scheme, PlanNode source) { return new PartitioningScheme(translateVariable(scheme.getPartitioning(), this::map), @@ -437,6 +525,25 @@ private List mapAndDistinctVariable(List newOrderingScheme = specification.getOrderingScheme() + .map(orderingScheme -> map(orderingScheme, preSorted)); + + return new SpecificationWithPreSortedPrefix( + new DataOrganizationSpecification( + mapAndDistinctVariable(specification.getPartitionBy()), + newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::getOrderingScheme)), + newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::getPreSorted).orElse(preSorted)); + } + + DataOrganizationSpecification mapAndDistinct(DataOrganizationSpecification specification) + { + return new DataOrganizationSpecification( + mapAndDistinctVariable(specification.getPartitionBy()), + specification.getOrderingScheme().map(this::map)); + } + public static SymbolMapper.Builder builder(WarningCollector warningCollector) { return new Builder(warningCollector); @@ -468,4 +575,48 @@ public void put(VariableReferenceExpression from, VariableReferenceExpression to mappingsBuilder.put(from, to); } } + + private static class OrderingSchemeWithPreSortedPrefix + { + private final OrderingScheme orderingScheme; + private final int preSorted; + + public OrderingSchemeWithPreSortedPrefix(OrderingScheme orderingScheme, int preSorted) + { + this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); + this.preSorted = preSorted; + } + + public OrderingScheme getOrderingScheme() + { + return orderingScheme; + } + + public int getPreSorted() + { + return preSorted; + } + } + + private static class SpecificationWithPreSortedPrefix + { + private final DataOrganizationSpecification specification; + private final int preSorted; + + public SpecificationWithPreSortedPrefix(DataOrganizationSpecification specification, int preSorted) + { + this.specification = requireNonNull(specification, "specification is null"); + this.preSorted = preSorted; + } + + public DataOrganizationSpecification getSpecification() + { + return specification; + } + + public int getPreSorted() + { + return preSorted; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index f0a6f30a8db29..c0c2961114fe8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -78,6 +78,7 @@ import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -86,6 +87,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; @@ -161,6 +163,11 @@ private Rewriter(TypeProvider types, FunctionAndTypeManager functionAndTypeManag this.warningCollector = warningCollector; } + public Map getMapping() + { + return mapping; + } + @Override public PlanNode visitAggregation(AggregationNode node, RewriteContext context) { @@ -500,18 +507,91 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext cont @Override public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext context) { + Map mappings = + Optional.ofNullable(context.get()) + .map(c -> new HashMap()) + .orElseGet(HashMap::new); + + SymbolMapper mapper = new SymbolMapper(mappings, warningCollector); + + List newProperOutputs = node.getOutputVariables().stream() + .map(mapper::map) + .collect(toImmutableList()); + + ImmutableList.Builder newSources = ImmutableList.builder(); + ImmutableList.Builder newTableArgumentProperties = ImmutableList.builder(); + + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode newSource = node.getSources().get(i).accept(this, context); + newSources.add(newSource); + + SymbolMapper inputMapper = new SymbolMapper(new HashMap<>(), warningCollector); + + TableFunctionNode.TableArgumentProperties properties = node.getTableArgumentProperties().get(i); + + Optional newSpecification = properties.getSpecification().map(inputMapper::mapAndDistinct); + TableFunctionNode.PassThroughSpecification newPassThroughSpecification = new TableFunctionNode.PassThroughSpecification( + properties.getPassThroughSpecification().isDeclaredAsPassThrough(), + properties.getPassThroughSpecification().getColumns().stream() + .map(column -> new TableFunctionNode.PassThroughColumn( + inputMapper.map(column.getOutputVariables()), + column.isPartitioningColumn())) + .collect(toImmutableList())); + newTableArgumentProperties.add(new TableFunctionNode.TableArgumentProperties( + properties.getArgumentName(), + properties.isRowSemantics(), + properties.isPruneWhenEmpty(), + newPassThroughSpecification, + inputMapper.map(properties.getRequiredColumns()), + newSpecification)); + } + return new TableFunctionNode( - node.getSourceLocation(), node.getId(), - Optional.empty(), node.getName(), node.getArguments(), - node.getOutputVariables(), - node.getSources(), - node.getTableArgumentProperties(), + newProperOutputs, + newSources.build(), + newTableArgumentProperties.build(), + node.getCopartitioningLists(), node.getHandle()); } + @Override + public PlanNode visitTableFunctionProcessor(TableFunctionProcessorNode node, RewriteContext context) + { + if (!node.getSource().isPresent()) { + Map mappings = + Optional.ofNullable(context.get()) + .map(c -> new HashMap()) + .orElseGet(HashMap::new); + SymbolMapper mapper = new SymbolMapper(mappings, warningCollector); + + TableFunctionProcessorNode rewrittenTableFunctionProcessor = new TableFunctionProcessorNode( + node.getId(), + node.getName(), + mapper.map(node.getProperOutputs()), + Optional.empty(), + node.isPruneWhenEmpty(), + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + node.getHashSymbol().map(mapper::map), + node.getHandle()); + + return rewrittenTableFunctionProcessor; + } + + PlanNode rewrittenSource = node.getSource().get().accept(this, context); + Map mappings = ((Rewriter) context.getNodeRewriter()).getMapping(); + SymbolMapper mapper = new SymbolMapper(mappings, types, warningCollector); + + return mapper.map(node, rewrittenSource); + } + @Override public PlanNode visitRowNumber(RowNumberNode node, RewriteContext context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java index a3e3ec3dc0d2c..6bfa05d32a9a8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java @@ -147,4 +147,9 @@ public R visitTableFunction(TableFunctionNode node, C context) { return visitPlan(node, context); } + + public R visitTableFunctionProcessor(TableFunctionProcessorNode node, C context) + { + return visitPlan(node, context); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java index a8af658030818..1db8a8b817eb5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java @@ -245,6 +245,11 @@ public static Pattern tableFunction() return typeOf(TableFunctionNode.class); } + public static Pattern tableFunctionProcessor() + { + return typeOf(TableFunctionProcessorNode.class); + } + public static Pattern rowNumber() { return typeOf(RowNumberNode.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java index 22d4f18e42ff9..f87c1a1bba5c5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java @@ -61,6 +61,11 @@ public C get() return userContext; } + public SimplePlanRewriter getNodeRewriter() + { + return nodeRewriter; + } + /** * Invoke the rewrite logic recursively on children of the given node and swap it * out with an identical copy with the rewritten children diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java index 97892523498c0..8838e82b48c91 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java @@ -22,13 +22,17 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.Immutable; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @Immutable @@ -40,6 +44,7 @@ public class TableFunctionNode private final List outputVariables; private final List sources; private final List tableArgumentProperties; + private final List> copartitioningLists; private final TableFunctionHandle handle; @JsonCreator @@ -50,9 +55,10 @@ public TableFunctionNode( @JsonProperty("outputVariables") List outputVariables, @JsonProperty("sources") List sources, @JsonProperty("tableArgumentProperties") List tableArgumentProperties, + @JsonProperty("copartitioningLists") List> copartitioningLists, @JsonProperty("handle") TableFunctionHandle handle) { - this(Optional.empty(), id, Optional.empty(), name, arguments, outputVariables, sources, tableArgumentProperties, handle); + this(Optional.empty(), id, Optional.empty(), name, arguments, outputVariables, sources, tableArgumentProperties, copartitioningLists, handle); } public TableFunctionNode( @@ -64,14 +70,18 @@ public TableFunctionNode( List outputVariables, List sources, List tableArgumentProperties, + List> copartitioningLists, TableFunctionHandle handle) { super(sourceLocation, id, statsEquivalentPlanNode); this.name = requireNonNull(name, "name is null"); - this.arguments = requireNonNull(arguments, "arguments is null"); - this.outputVariables = requireNonNull(outputVariables, "outputVariables is null"); - this.sources = requireNonNull(sources, "sources is null"); - this.tableArgumentProperties = requireNonNull(tableArgumentProperties, "tableArgumentProperties is null"); + this.arguments = ImmutableMap.copyOf(arguments); + this.outputVariables = ImmutableList.copyOf(outputVariables); + this.sources = ImmutableList.copyOf(sources); + this.tableArgumentProperties = ImmutableList.copyOf(tableArgumentProperties); + this.copartitioningLists = copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); this.handle = requireNonNull(handle, "handle is null"); } @@ -87,8 +97,23 @@ public Map getArguments() return arguments; } - @JsonProperty + @Override public List getOutputVariables() + { + ImmutableList.Builder variables = ImmutableList.builder(); + variables.addAll(outputVariables); + + tableArgumentProperties.stream() + .map(TableArgumentProperties::getPassThroughSpecification) + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(PassThroughColumn::getOutputVariables) + .forEach(variables::add); + + return variables.build(); + } + + public List getProperOutputs() { return outputVariables; } @@ -99,6 +124,12 @@ public List getTableArgumentProperties() return tableArgumentProperties; } + @JsonProperty + public List> getCopartitioningLists() + { + return copartitioningLists; + } + @JsonProperty public TableFunctionHandle getHandle() { @@ -122,35 +153,47 @@ public R accept(InternalPlanVisitor visitor, C context) public PlanNode replaceChildren(List newSources) { checkArgument(sources.size() == newSources.size(), "wrong number of new children"); - return new TableFunctionNode(getId(), name, arguments, outputVariables, newSources, tableArgumentProperties, handle); + return new TableFunctionNode(getId(), name, arguments, outputVariables, newSources, tableArgumentProperties, copartitioningLists, handle); } @Override public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) { - return new TableFunctionNode(getSourceLocation(), getId(), statsEquivalentPlanNode, name, arguments, outputVariables, sources, tableArgumentProperties, handle); + return new TableFunctionNode(getSourceLocation(), getId(), statsEquivalentPlanNode, name, arguments, outputVariables, sources, tableArgumentProperties, copartitioningLists, handle); } public static class TableArgumentProperties { + private final String argumentName; private final boolean rowSemantics; private final boolean pruneWhenEmpty; - private final boolean passThroughColumns; + private final PassThroughSpecification passThroughSpecification; + private final List requiredColumns; private final Optional specification; @JsonCreator public TableArgumentProperties( + @JsonProperty("argumentName") String argumentName, @JsonProperty("rowSemantics") boolean rowSemantics, @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, - @JsonProperty("passThroughColumns") boolean passThroughColumns, + @JsonProperty("passThroughSpecification") PassThroughSpecification passThroughSpecification, + @JsonProperty("requiredColumns") List requiredColumns, @JsonProperty("specification") Optional specification) { + this.argumentName = requireNonNull(argumentName, "argumentName is null"); this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; - this.passThroughColumns = passThroughColumns; + this.passThroughSpecification = requireNonNull(passThroughSpecification, "passThroughSpecification is null"); + this.requiredColumns = ImmutableList.copyOf(requiredColumns); this.specification = requireNonNull(specification, "specification is null"); } + @JsonProperty + public String getArgumentName() + { + return argumentName; + } + @JsonProperty public boolean isRowSemantics() { @@ -164,15 +207,83 @@ public boolean isPruneWhenEmpty() } @JsonProperty - public boolean isPassThroughColumns() + public PassThroughSpecification getPassThroughSpecification() + { + return passThroughSpecification; + } + + @JsonProperty + public List getRequiredColumns() { - return passThroughColumns; + return requiredColumns; } @JsonProperty - public Optional specification() + public Optional getSpecification() { return specification; } } + + /** + * Specifies how columns from source tables are passed through to the output of a table function. + * This class manages both explicitly declared pass-through columns and partitioning columns + * that must be preserved in the output. + */ + public static class PassThroughSpecification + { + private final boolean declaredAsPassThrough; + private final List columns; + + @JsonCreator + public PassThroughSpecification( + @JsonProperty("declaredAsPassThrough") boolean declaredAsPassThrough, + @JsonProperty("columns") List columns) + { + this.declaredAsPassThrough = declaredAsPassThrough; + this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); + checkArgument( + declaredAsPassThrough || this.columns.stream().allMatch(PassThroughColumn::isPartitioningColumn), + "non-partitioning pass-through column for non-pass-through source of a table function"); + } + + @JsonProperty + public boolean isDeclaredAsPassThrough() + { + return declaredAsPassThrough; + } + + @JsonProperty + public List getColumns() + { + return columns; + } + } + + public static class PassThroughColumn + { + private final VariableReferenceExpression outputVariables; + private final boolean isPartitioningColumn; + + @JsonCreator + public PassThroughColumn( + @JsonProperty("outputVariables") VariableReferenceExpression outputVariables, + @JsonProperty("partitioningColumn") boolean isPartitioningColumn) + { + this.outputVariables = requireNonNull(outputVariables, "symbol is null"); + this.isPartitioningColumn = isPartitioningColumn; + } + + @JsonProperty + public VariableReferenceExpression getOutputVariables() + { + return outputVariables; + } + + @JsonProperty + public boolean isPartitioningColumn() + { + return isPartitioningColumn; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java new file mode 100644 index 0000000000000..851f776a2c90f --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java @@ -0,0 +1,286 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.plan; + +import com.facebook.presto.metadata.TableFunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; + +public class TableFunctionProcessorNode + extends InternalPlanNode +{ + private final String name; + + // symbols produced by the function + private final List properOutputs; + + // pre-planned sources + private final Optional source; + // TODO do we need the info of which source has row semantics, or is it already included in the joins / join distribution? + + // specifies whether the function should be pruned or executed when the input is empty + // pruneWhenEmpty is false if and only if all original input tables are KEEP WHEN EMPTY + private final boolean pruneWhenEmpty; + + // all source symbols to be produced on output, ordered as table argument specifications + private final List passThroughSpecifications; + + // symbols required from each source, ordered as table argument specifications + private final List> requiredVariables; + + // mapping from source symbol to helper "marker" symbol which indicates whether the source value is valid + // for processing or for pass-through. null value in the marker column indicates that the value at the same + // position in the source column should not be processed or passed-through. + // the mapping is only present if there are two or more sources. + // + // Example: + // Given two input tables T1(a,b) PARTITION BY a and T2(c, d) PARTITION BY c + // T1 partitions: T2 partitions: + // a | b c | d + // ---+--- ---+--- + // 1 | 10 5 | 50 + // 1 | 20 5 | 60 + // 1 | 30 6 | 90 + // 2 | 40 6 | 100 + // 2 | 50 6 | 110 + // + // TransformTableFunctionToTableFunctionProcessor creates a join that produces a cartesian product of partitions from each table, resulting in 4 partitions: + // + // Partition (a=1, c=5): + // a | b | marker_1 | c | d | marker_2 + // -----+------+----------+----+-----+---------- + // 1 | 10 | 1 | 5 | 50 | 1 (row 1 from both partitions) + // 1 | 20 | 2 | 5 | 60 | 2 (row 2 from both partitions) + // 1 | 30 | 3 | 5 | 50 | null (filler row for T2, real row 3 from T1) + // + // Partition (a=1, c=6): + // a | b | marker_1 | c | d | marker_2 + // -----+------+----------+----+-----+---------- + // 1 | 10 | 1 | 6 | 90 | 1 (row 1 from both partitions) + // 1 | 20 | 2 | 6 | 100 | 2 (row 2 from both partitions) + // 1 | 30 | 3 | 6 | 110 | 3 (row 3 from both partitions) + // + // Partition (a=2, c=5): + // a | b | marker_1 | c | d | marker_2 + // -----+------+----------+----+-----+---------- + // 2 | 40 | 1 | 5 | 50 | 1 (row 1 from both partitions) + // 2 | 50 | 2 | 5 | 60 | 2 (row 2 from both partitions) + // + // Partition (a=2, c=6): + // a | b | marker_1 | c | d | marker_2 + // -----+------+----------+----+-----+---------- + // 2 | 40 | 1 | 6 | 90 | 1 (row 1 from both partitions) + // 2 | 50 | 2 | 6 | 100 | 2 (row 2 from both partitions) + // 2 | 40 | null | 6 | 110 | 3 (filler row for T1, real row 3 from T2) + // + // markerVariables map: + // { + // VariableReferenceExpression(a) -> VariableReferenceExpression(marker_1), + // VariableReferenceExpression(b) -> VariableReferenceExpression(marker_1), + // VariableReferenceExpression(c) -> VariableReferenceExpression(marker_2), + // VariableReferenceExpression(d) -> VariableReferenceExpression(marker_2) + // } + // + // When marker_1 is null, columns a and b should not be processed or passed-through. + // When marker_2 is null, columns c and d should not be processed or passed-through. + + private final Optional> markerVariables; + + private final Optional specification; + private final Set prePartitioned; + private final int preSorted; + private final Optional hashSymbol; + + private final TableFunctionHandle handle; + + @JsonCreator + public TableFunctionProcessorNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("name") String name, + @JsonProperty("properOutputs") List properOutputs, + @JsonProperty("source") Optional source, + @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, + @JsonProperty("passThroughSpecifications") List passThroughSpecifications, + @JsonProperty("requiredVariables") List> requiredVariables, + @JsonProperty("markerVariables") Optional> markerVariables, + @JsonProperty("specification") Optional specification, + @JsonProperty("prePartitioned") Set prePartitioned, + @JsonProperty("preSorted") int preSorted, + @JsonProperty("hashSymbol") Optional hashSymbol, + @JsonProperty("handle") TableFunctionHandle handle) + { + super(Optional.empty(), id, Optional.empty()); + this.name = requireNonNull(name, "name is null"); + this.properOutputs = ImmutableList.copyOf(properOutputs); + this.source = requireNonNull(source, "source is null"); + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); + this.requiredVariables = requiredVariables.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerVariables = markerVariables.map(ImmutableMap::copyOf); + this.specification = requireNonNull(specification, "specification is null"); + this.prePartitioned = ImmutableSet.copyOf(prePartitioned); + Set partitionBy = specification + .map(DataOrganizationSpecification::getPartitionBy) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + checkArgument(partitionBy.containsAll(prePartitioned), "all pre-partitioned symbols must be contained in the partitioning list"); + this.preSorted = preSorted; + checkArgument( + specification + .flatMap(DataOrganizationSpecification::getOrderingScheme) + .map(OrderingScheme::getOrderBy) + .map(List::size) + .orElse(0) >= preSorted, + "the number of pre-sorted symbols cannot be greater than the number of all ordering symbols"); + checkArgument(preSorted == 0 || partitionBy.equals(prePartitioned), "to specify pre-sorted symbols, it is required that all partitioning symbols are pre-partitioned"); + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + this.handle = requireNonNull(handle, "handle is null"); + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public List getProperOutputs() + { + return properOutputs; + } + + @JsonProperty + public Optional getSource() + { + return source; + } + + @JsonProperty + public boolean isPruneWhenEmpty() + { + return pruneWhenEmpty; + } + + @JsonProperty + public List getPassThroughSpecifications() + { + return passThroughSpecifications; + } + + @JsonProperty + public List> getRequiredVariables() + { + return requiredVariables; + } + + @JsonProperty + public Optional> getMarkerVariables() + { + return markerVariables; + } + + @JsonProperty + public Optional getSpecification() + { + return specification; + } + + @JsonProperty + public Set getPrePartitioned() + { + return prePartitioned; + } + + @JsonProperty + public int getPreSorted() + { + return preSorted; + } + + @JsonProperty + public Optional getHashSymbol() + { + return hashSymbol; + } + + @JsonProperty + public TableFunctionHandle getHandle() + { + return handle; + } + + @JsonProperty + @Override + public List getSources() + { + return source.map(ImmutableList::of).orElse(ImmutableList.of()); + } + + @Override + public List getOutputVariables() + { + ImmutableList.Builder variables = ImmutableList.builder(); + + variables.addAll(properOutputs); + + passThroughSpecifications.stream() + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .forEach(variables::add); + + return variables.build(); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return this; + } + + @Override + public PlanNode replaceChildren(List newSources) + { + Optional newSource = newSources.isEmpty() ? Optional.empty() : Optional.of(getOnlyElement(newSources)); + return new TableFunctionProcessorNode(getId(), name, properOutputs, newSource, pruneWhenEmpty, passThroughSpecifications, requiredVariables, markerVariables, specification, prePartitioned, preSorted, hashSymbol, handle); + } + + @Override + public R accept(InternalPlanVisitor visitor, C context) + { + return visitor.visitTableFunctionProcessor(this, context); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 93c049ba781ec..c35f28525bb09 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -35,6 +35,9 @@ import com.facebook.presto.spi.SourceLocation; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.ScalarArgument; import com.facebook.presto.spi.plan.AbstractJoinNode; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; @@ -96,13 +99,14 @@ import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.MergeProcessorNode; import com.facebook.presto.sql.planner.plan.MergeWriterNode; -import com.facebook.presto.sql.planner.plan.OffsetNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -114,6 +118,7 @@ import com.google.common.base.Functions; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSortedMap; import com.google.common.collect.Iterables; @@ -122,11 +127,13 @@ import io.airlift.slice.Slice; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -139,6 +146,7 @@ import static com.facebook.presto.execution.StageInfo.getAllStages; import static com.facebook.presto.expressions.DynamicFilters.extractDynamicFilters; import static com.facebook.presto.metadata.CastType.CAST; +import static com.facebook.presto.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference; import static com.facebook.presto.sql.planner.SortExpressionExtractor.getSortExpressionContext; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -152,10 +160,12 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.lang.String.format; import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.function.Function.identity; import static java.util.stream.Collectors.toList; public class PlanPrinter @@ -1356,13 +1366,6 @@ public Void visitLateralJoin(LateralJoinNode node, Void context) return processChildren(node, context); } - @Override - public Void visitOffset(OffsetNode node, Void context) - { - addNode(node, "Offset", format("[%s]", node.getCount())); - return processChildren(node, context); - } - @Override public Void visitTableFunction(TableFunctionNode node, Void context) { @@ -1371,9 +1374,177 @@ public Void visitTableFunction(TableFunctionNode node, Void context) "TableFunction", node.getName()); - checkArgument( - node.getSources().isEmpty() && node.getTableArgumentProperties().isEmpty(), - "Table or descriptor arguments are not yet supported in PlanPrinter"); + if (!node.getArguments().isEmpty()) { + nodeOutput.appendDetails("Arguments:"); + + Map tableArguments = node.getTableArgumentProperties().stream() + .collect(toImmutableMap(TableArgumentProperties::getArgumentName, identity())); + + node.getArguments().entrySet().stream() + .forEach(entry -> nodeOutput.appendDetailsLine(formatArgument(entry.getKey(), entry.getValue(), tableArguments))); + + if (!node.getCopartitioningLists().isEmpty()) { + nodeOutput.appendDetailsLine(node.getCopartitioningLists().stream() + .map(list -> list.stream().collect(Collectors.joining(", ", "(", ")"))) + .collect(Collectors.joining(", ", "Co-partition: [", "] "))); + } + } + + processChildren(node, context); + + return null; + } + + private String formatArgument(String argumentName, Argument argument, Map tableArguments) + { + if (argument instanceof ScalarArgument) { + ScalarArgument scalarArgument = (ScalarArgument) argument; + return formatScalarArgument(argumentName, scalarArgument); + } + if (argument instanceof DescriptorArgument) { + DescriptorArgument descriptorArgument = (DescriptorArgument) argument; + return formatDescriptorArgument(argumentName, descriptorArgument); + } + else { + TableArgumentProperties argumentProperties = tableArguments.get(argumentName); + return formatTableArgument(argumentName, argumentProperties); + } + } + + private String formatScalarArgument(String argumentName, ScalarArgument argument) + { + return format( + "%s => ScalarArgument{type=%s, value=%s}", + argumentName, + argument.getType().getDisplayName(), + argument.getValue()); + } + + private String formatDescriptorArgument(String argumentName, DescriptorArgument argument) + { + String descriptor; + if (argument.equals(NULL_DESCRIPTOR)) { + descriptor = "NULL"; + } + else { + descriptor = argument.getDescriptor().orElseThrow(() -> new IllegalStateException("Missing descriptor")).getFields().stream() + .map(field -> field.getName() + field.getType().map(type -> " " + type.getDisplayName()).orElse("")) + .collect(Collectors.joining(", ", "(", ")")); + } + return format("%s => DescriptorArgument{%s}", argumentName, descriptor); + } + + private String formatTableArgument(String argumentName, TableArgumentProperties argumentProperties) + { + List properties = new ArrayList<>(); + + if (argumentProperties.isRowSemantics()) { + properties.add("row semantics "); + } + argumentProperties.getSpecification().ifPresent(specification -> { + StringBuilder specificationBuilder = new StringBuilder(); + specificationBuilder + .append("partition by: [") + .append(Joiner.on(", ").join(specification.getPartitionBy())) + .append("]"); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + specificationBuilder + .append(", order by: ") + .append(formatOrderingScheme(orderingScheme)); + }); + properties.add(specificationBuilder.toString()); + }); + + properties.add("required columns: [" + + Joiner.on(", ").join(argumentProperties.getRequiredColumns()) + "]"); + + if (argumentProperties.isPruneWhenEmpty()) { + properties.add("prune when empty"); + } + + if (argumentProperties.getPassThroughSpecification().isDeclaredAsPassThrough()) { + properties.add("pass through columns"); + } + + return format("%s => TableArgument{%s}", argumentName, Joiner.on(", ").join(properties)); + } + + private String formatOrderingScheme(OrderingScheme orderingScheme) + { + return formatCollection(orderingScheme.getOrderByVariables(), variable -> variable + " " + orderingScheme.getOrdering(variable)); + } + + private String formatOrderingScheme(OrderingScheme orderingScheme, int preSortedOrderPrefix) + { + List orderBy = Stream.concat( + orderingScheme.getOrderByVariables().stream() + .limit(preSortedOrderPrefix) + .map(variable -> "<" + variable + " " + orderingScheme.getOrdering(variable) + ">"), + orderingScheme.getOrderByVariables().stream() + .skip(preSortedOrderPrefix) + .map(variable -> variable + " " + orderingScheme.getOrdering(variable))) + .collect(toImmutableList()); + return formatCollection(orderBy, Objects::toString); + } + + public String formatCollection(Collection collection, Function formatter) + { + return collection.stream() + .map(formatter) + .collect(Collectors.joining(", ", "[", "]")); + } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Void context) + { + ImmutableMap.Builder descriptor = ImmutableMap.builder(); + + descriptor.put("name", node.getName()); + + descriptor.put("properOutputs", format("[%s]", Joiner.on(", ").join(node.getProperOutputs()))); + + String specs = node.getPassThroughSpecifications().stream() + .map(spec -> spec.getColumns().stream() + .map(col -> col.getOutputVariables().toString()) + .collect(Collectors.joining(", ", "[", "]"))) + .collect(Collectors.joining(", ")); + descriptor.put("passThroughSymbols", format("[%s]", specs)); + + String requiredSymbols = node.getRequiredVariables().stream() + .map(vars -> vars.stream() + .map(VariableReferenceExpression::toString) + .collect(Collectors.joining(", ", "[", "]"))) + .collect(Collectors.joining(", ", "[", "]")); + descriptor.put("requiredSymbols", format("[%s]", requiredSymbols)); + + node.getSpecification().ifPresent(specification -> { + if (!specification.getPartitionBy().isEmpty()) { + List prePartitioned = specification.getPartitionBy().stream() + .filter(node.getPrePartitioned()::contains) + .collect(toImmutableList()); + + List notPrePartitioned = specification.getPartitionBy().stream() + .filter(column -> !node.getPrePartitioned().contains(column)) + .collect(toImmutableList()); + + StringBuilder builder = new StringBuilder(); + if (!prePartitioned.isEmpty()) { + builder.append(prePartitioned.stream() + .map(VariableReferenceExpression::toString) + .collect(Collectors.joining(", ", "<", ">"))); + if (!notPrePartitioned.isEmpty()) { + builder.append(", "); + } + } + if (!notPrePartitioned.isEmpty()) { + builder.append(Joiner.on(", ").join(notPrePartitioned)); + } + descriptor.put("partitionBy", format("[%s]", builder)); + } + specification.getOrderingScheme().ifPresent(orderingScheme -> descriptor.put("orderBy", formatOrderingScheme(orderingScheme, node.getPreSorted()))); + }); + + addNode(node, "TableFunctionProcessorNode" + descriptor.build()); return processChildren(node, context); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index bf93c5d23388f..1a88f259c882e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -74,6 +74,8 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -88,6 +90,7 @@ import static com.facebook.presto.spi.plan.JoinNode.checkLeftOutputVariablesBeforeRight; import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.extractAggregationUniqueVariables; import static com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer.IndexKeyTracer; +import static com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -124,6 +127,117 @@ public Void visitPlan(PlanNode node, Set boundVaria @Override public Void visitTableFunction(TableFunctionNode node, Set boundSymbols) { + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode source = node.getSources().get(i); + source.accept(this, boundSymbols); + Set inputs = createInputs(source, boundSymbols); + TableFunctionNode.TableArgumentProperties argumentProperties = node.getTableArgumentProperties().get(i); + + checkDependencies( + inputs, + argumentProperties.getRequiredColumns(), + "Invalid node. Required input symbols from source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + argumentProperties.getRequiredColumns(), + source.getOutputVariables()); + argumentProperties.getSpecification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. Partition by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + specification.getPartitionBy(), + source.getOutputVariables()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderByVariables(), + "Invalid node. Order by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + orderingScheme.getOrderBy(), + source.getOutputVariables()); + }); + }); + Set passThroughVariable = argumentProperties.getPassThroughSpecification().getColumns().stream() + .map(PassThroughColumn::getOutputVariables) + .collect(toImmutableSet()); + checkDependencies( + inputs, + passThroughVariable, + "Invalid node. Pass-through symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + passThroughVariable, + source.getOutputVariables()); + } + return null; + } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Set boundVariables) + { + if (!node.getSource().isPresent()) { + return null; + } + + PlanNode source = node.getSource().get(); + source.accept(this, boundVariables); + + Set inputs = createInputs(source, boundVariables); + + Set passThroughSymbols = node.getPassThroughSpecifications().stream() + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(PassThroughColumn::getOutputVariables) + .collect(toImmutableSet()); + checkDependencies( + inputs, + passThroughSymbols, + "Invalid node. Pass-through symbols (%s) not in source plan output (%s)", + passThroughSymbols, + source.getOutputVariables()); + + Set requiredSymbols = node.getRequiredVariables().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + checkDependencies( + inputs, + requiredSymbols, + "Invalid node. Required symbols (%s) not in source plan output (%s)", + requiredSymbols, + source.getOutputVariables()); + + node.getMarkerVariables().ifPresent(mapping -> { + checkDependencies( + inputs, + mapping.keySet(), + "Invalid node. Source symbols (%s) not in source plan output (%s)", + mapping.keySet(), + source.getOutputVariables()); + checkDependencies( + inputs, + mapping.values(), + "Invalid node. Source marker symbols (%s) not in source plan output (%s)", + mapping.values(), + source.getOutputVariables()); + }); + + node.getSpecification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. Partition by symbols (%s) not in source plan output (%s)", + specification.getPartitionBy(), + source.getOutputVariables()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderByVariables(), + "Invalid node. Order by symbols (%s) not in source plan output (%s)", + orderingScheme.getOrderBy(), + source.getOutputVariables()); + }); + }); + return null; } diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index aaad7753c672f..db8e814b45f72 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -125,6 +125,7 @@ import com.facebook.presto.operator.TableCommitContext; import com.facebook.presto.operator.TaskContext; import com.facebook.presto.operator.index.IndexJoinLookupStats; +import com.facebook.presto.operator.table.ExcludeColumns; import com.facebook.presto.server.NodeStatusNotificationManager; import com.facebook.presto.server.PluginManager; import com.facebook.presto.server.PluginManagerConfig; @@ -527,7 +528,10 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new ColumnPropertiesSystemTable(transactionManager, metadata), new AnalyzePropertiesSystemTable(transactionManager, metadata), new TransactionsSystemTable(metadata.getFunctionAndTypeManager(), transactionManager)), - ImmutableSet.of()); + ImmutableSet.of(), + ImmutableSet.of(new ExcludeColumns.ExcludeColumnsFunction()), + null, + getFunctionAndTypeManager()); BuiltInQueryAnalyzer queryAnalyzer = new BuiltInQueryAnalyzer(metadata, sqlParser, accessControl, Optional.empty(), metadataExtractorExecutor); BuiltInAnalyzerProvider analyzerProvider = new BuiltInAnalyzerProvider(queryAnalyzer); @@ -779,7 +783,8 @@ public void installPlugin(Plugin plugin) @Override public void createCatalog(String catalogName, String connectorName, Map properties) { - throw new UnsupportedOperationException(); + nodeManager.addCurrentNodeConnector(new ConnectorId(catalogName)); + connectorManager.createConnection(catalogName, connectorName, properties); } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/QueryRunner.java b/presto-main-base/src/main/java/com/facebook/presto/testing/QueryRunner.java index f6c1a0b9b1405..26fa51563aab9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/QueryRunner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/QueryRunner.java @@ -114,6 +114,10 @@ default void loadSessionPropertyProvider(String sessionPropertyProviderName, Map throw new UnsupportedOperationException(); } + default void loadTVFProvider(String tvfProviderName) + { + throw new UnsupportedOperationException(); + } Lock getExclusiveLock(); default void loadTypeManager(String typeManagerName) diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index da4a7fc2cd4e0..5368f400f4646 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -70,6 +70,8 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -136,6 +138,8 @@ private enum NodeType EXPLAIN_ANALYZE, UPDATE, MERGE, + TABLE_FUNCTION, + TABLE_FUNCTION_PROCESSOR } private static final Map NODE_COLORS = immutableEnumMap(ImmutableMap.builder() @@ -168,6 +172,8 @@ private enum NodeType .put(NodeType.EXPLAIN_ANALYZE, "cadetblue1") .put(NodeType.UPDATE, "blue") .put(NodeType.MERGE, "lightblue") + .put(NodeType.TABLE_FUNCTION, "mediumorchid3") + .put(NodeType.TABLE_FUNCTION_PROCESSOR, "steelblue3") .build()); static { @@ -409,6 +415,24 @@ public Void visitWindow(WindowNode node, Void context) return node.getSource().accept(this, context); } + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Void context) + { + printNode(node, "Table Function Processor", NODE_COLORS.get(NodeType.TABLE_FUNCTION)); + if (node.getSource().isPresent()) { + node.getSource().get().accept(this, context); + } + return null; + } + + @Override + public Void visitTableFunction(TableFunctionNode node, Void context) + { + printNode(node, "Table Function Node", NODE_COLORS.get(NodeType.TABLE_FUNCTION)); + node.getSources().stream().map(source -> source.accept(this, context)); + return null; + } + @Override public Void visitRowNumber(RowNumberNode node, Void context) { diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorFactory.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorFactory.java index 2674c87d28cc6..f64d51a0c10f3 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorFactory.java +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorFactory.java @@ -47,8 +47,11 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.schedule.NodeSelectionStrategy; import com.facebook.presto.spi.statistics.TableStatistics; import com.facebook.presto.spi.transaction.IsolationLevel; @@ -57,6 +60,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import java.util.Collections; import java.util.List; @@ -69,6 +73,7 @@ import java.util.stream.IntStream; import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.connector.tvf.TestTVFConnectorFactory.TestTVFConnector.TestTVFConnectorSplit.TEST_TVF_CONNECTOR_SPLIT; import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -84,6 +89,10 @@ public class TestTVFConnectorFactory private final Supplier getTableStatistics; private final ApplyTableFunction applyTableFunction; private final Set tableFunctions; + private final Function tableFunctionProcessorProvider; + private final TestTvfTableFunctionHandleResolver tableFunctionHandleResolver; + private final TestTvfTableFunctionSplitResolver tableFunctionSplitResolver; + private final Function tableFunctionSplitsSources; private TestTVFConnectorFactory( Function> listSchemaNames, @@ -92,7 +101,11 @@ private TestTVFConnectorFactory( BiFunction> getColumnHandles, Supplier getTableStatistics, ApplyTableFunction applyTableFunction, - Set tableFunctions) + Set tableFunctions, + Function getTableFunctionProcessorProvider, + TestTvfTableFunctionHandleResolver tableFunctionHandleResolver, + TestTvfTableFunctionSplitResolver tableFunctionSplitResolver, + Function tableFunctionSplitsSources) { this.listSchemaNames = requireNonNull(listSchemaNames, "listSchemaNames is null"); this.listTables = requireNonNull(listTables, "listTables is null"); @@ -101,6 +114,10 @@ private TestTVFConnectorFactory( this.getTableStatistics = requireNonNull(getTableStatistics, "getTableStatistics is null"); this.applyTableFunction = requireNonNull(applyTableFunction, "applyTableFunction is null"); this.tableFunctions = requireNonNull(tableFunctions, "tableFunctions is null"); + this.tableFunctionProcessorProvider = requireNonNull(getTableFunctionProcessorProvider, "tableFunctionProcessorProvider is null"); + this.tableFunctionHandleResolver = requireNonNull(tableFunctionHandleResolver, "tableFunctionHandleResolver is null"); + this.tableFunctionSplitResolver = requireNonNull(tableFunctionSplitResolver, "tableFunctionSplitResolver is null"); + this.tableFunctionSplitsSources = requireNonNull(tableFunctionSplitsSources, "tableFunctionSplitsSources is null"); } @Override @@ -115,10 +132,22 @@ public ConnectorHandleResolver getHandleResolver() return new TestTVFHandleResolver(); } + @Override + public Optional getTableFunctionHandleResolver() + { + return Optional.of(tableFunctionHandleResolver); + } + + @Override + public Optional getTableFunctionSplitResolver() + { + return Optional.of(tableFunctionSplitResolver); + } + @Override public Connector create(String catalogName, Map config, ConnectorContext context) { - return new TestTVFConnector(context, listSchemaNames, listTables, getViews, getColumnHandles, getTableStatistics, applyTableFunction, tableFunctions); + return new TestTVFConnector(context, listSchemaNames, listTables, getViews, getColumnHandles, getTableStatistics, applyTableFunction, tableFunctions, tableFunctionProcessorProvider, tableFunctionSplitsSources); } public static Builder builder() @@ -154,7 +183,9 @@ public static class TestTVFConnector private final BiFunction> getColumnHandles; private final Supplier getTableStatistics; private final ApplyTableFunction applyTableFunction; + private final Function tableFunctionProcessorProvider; private final Set tableFunctions; + private final Function tableFunctionSplitsSources; public TestTVFConnector( ConnectorContext context, @@ -164,7 +195,9 @@ public TestTVFConnector( BiFunction> getColumnHandles, Supplier getTableStatistics, ApplyTableFunction applyTableFunction, - Set tableFunctions) + Set tableFunctions, + Function getTableFunctionProcessorProvider, + Function tableFunctionSplitsSources) { this.context = requireNonNull(context, "context is null"); this.listSchemaNames = requireNonNull(listSchemaNames, "listSchemaNames is null"); @@ -174,6 +207,8 @@ public TestTVFConnector( this.getTableStatistics = requireNonNull(getTableStatistics, "getTableStatistics is null"); this.applyTableFunction = requireNonNull(applyTableFunction, "applyTableFunction is null"); this.tableFunctions = requireNonNull(tableFunctions, "tableFunctions is null"); + this.tableFunctionProcessorProvider = requireNonNull(getTableFunctionProcessorProvider, "tableFunctionProcessorProvider is null"); + this.tableFunctionSplitsSources = requireNonNull(tableFunctionSplitsSources, "tableFunctionSplitsSources is null"); } @Override @@ -220,7 +255,15 @@ public ConnectorSplitManager getSplitManager() @Override public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) { - return new FixedSplitSource(Collections.singleton(TestTVFConnectorSplit.TEST_TVF_CONNECTOR_SPLIT)); + return new FixedSplitSource(Collections.singleton(TEST_TVF_CONNECTOR_SPLIT)); + } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorTableFunctionHandle functionHandle) + { + ConnectorSplitSource splits = tableFunctionSplitsSources.apply(functionHandle); + return requireNonNull(splits, "missing ConnectorSplitSource for table function handle " + + functionHandle.getClass().getSimpleName()); } }; } @@ -243,6 +286,12 @@ public Set getTableFunctions() return tableFunctions; } + @Override + public Function getTableFunctionProcessorProvider() + { + return tableFunctionProcessorProvider; + } + private class TestTVFConnectorMetadata implements ConnectorMetadata { @@ -382,6 +431,40 @@ private Map getColumnIndexes(SchemaTableName tableName) } } + public static class TestTvfTableFunctionHandleResolver + implements TableFunctionHandleResolver + { + Set> handles = Sets.newHashSet(); + + @Override + public Set> getTableFunctionHandleClasses() + { + return handles; + } + + public void addTableFunctionHandle(Class tableFunctionHandleClass) + { + handles.add(tableFunctionHandleClass); + } + } + + public static class TestTvfTableFunctionSplitResolver + implements TableFunctionSplitResolver + { + Set> handles = Sets.newHashSet(); + + @Override + public Set> getTableFunctionSplitClasses() + { + return handles; + } + + public void addSplitClass(Class splitClass) + { + handles.add(splitClass); + } + } + public static final class Builder { private Function> listSchemaNames = (session) -> ImmutableList.of(); @@ -396,6 +479,10 @@ public static final class Builder private Supplier getTableStatistics = TableStatistics::empty; private ApplyTableFunction applyTableFunction = (session, handle) -> Optional.empty(); private Set tableFunctions = ImmutableSet.of(); + private Function tableFunctionProcessorProvider = handle -> null; + private final TestTvfTableFunctionHandleResolver tableFunctionHandleResolver = new TestTvfTableFunctionHandleResolver(); + private TestTvfTableFunctionSplitResolver tableFunctionSplitResolver = new TestTvfTableFunctionSplitResolver(); + private Function tableFunctionSplitsSources = handle -> null; public Builder withListSchemaNames(Function> listSchemaNames) { @@ -439,14 +526,38 @@ public Builder withTableFunctions(Iterable tableFunction return this; } + public Builder withTableFunctionProcessorProvider(Function tableFunctionProcessorProvider) + { + this.tableFunctionProcessorProvider = tableFunctionProcessorProvider; + return this; + } + + public Builder withTableFunctionResolver(Class tableFunctionHandleclass) + { + this.tableFunctionHandleResolver.addTableFunctionHandle(tableFunctionHandleclass); + return this; + } + + public Builder withTableFunctionSplitResolver(Class splitClass) + { + this.tableFunctionSplitResolver.addSplitClass(splitClass); + return this; + } + public TestTVFConnectorFactory build() { - return new TestTVFConnectorFactory(listSchemaNames, listTables, getViews, getColumnHandles, getTableStatistics, applyTableFunction, tableFunctions); + return new TestTVFConnectorFactory(listSchemaNames, listTables, getViews, getColumnHandles, getTableStatistics, applyTableFunction, tableFunctions, tableFunctionProcessorProvider, tableFunctionHandleResolver, tableFunctionSplitResolver, tableFunctionSplitsSources); } private static T notSupported() { throw new UnsupportedOperationException(); } + + public Builder withTableFunctionSplitSource(Function sourceProvider) + { + tableFunctionSplitsSources = requireNonNull(sourceProvider, "sourceProvider is null"); + return this; + } } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java index 96373d826b50a..004b1339b48db 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java @@ -13,39 +13,68 @@ */ package com.facebook.presto.connector.tvf; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.RowType; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.function.SchemaFunctionName; import com.facebook.presto.spi.function.table.AbstractConnectorTableFunction; import com.facebook.presto.spi.function.table.Argument; import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.DescribedTableReturnTypeSpecification; import com.facebook.presto.spi.function.table.Descriptor; import com.facebook.presto.spi.function.table.DescriptorArgumentSpecification; -import com.facebook.presto.spi.function.table.ReturnTypeSpecification; +import com.facebook.presto.spi.function.table.GenericTableReturnTypeSpecification; import com.facebook.presto.spi.function.table.ScalarArgument; import com.facebook.presto.spi.function.table.ScalarArgumentSpecification; +import com.facebook.presto.spi.function.table.TableArgument; import com.facebook.presto.spi.function.table.TableArgumentSpecification; import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.facebook.presto.spi.function.table.TableFunctionSplitProcessor; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; +import org.openjdk.jol.info.ClassLayout; import java.util.Arrays; +import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; +import java.util.stream.IntStream; +import static com.facebook.presto.common.Utils.checkArgument; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.VarcharType.VARCHAR; -import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.DescribedTable; -import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; -import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.OnlyPassThrough.ONLY_PASS_THROUGH; +import static com.facebook.presto.spi.function.table.GenericTableReturnTypeSpecification.GENERIC_TABLE; +import static com.facebook.presto.spi.function.table.OnlyPassThroughReturnTypeSpecification.ONLY_PASS_THROUGH; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.produced; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInput; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInputAndProduced; +import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; import static io.airlift.slice.Slices.utf8Slice; +import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class TestingTableFunctions @@ -67,18 +96,17 @@ public class TestingTableFunctions public static class TestConnectorTableFunction extends AbstractConnectorTableFunction { - private static final String TEST_FUNCTION = "test_function"; - + private static final String FUNCTION_NAME = "test_function"; public TestConnectorTableFunction() { - super(SCHEMA_NAME, TEST_FUNCTION, ImmutableList.of(), ReturnTypeSpecification.GenericTable.GENERIC_TABLE); + super(SCHEMA_NAME, FUNCTION_NAME, ImmutableList.of(), GenericTableReturnTypeSpecification.GENERIC_TABLE); } @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { return TableFunctionAnalysis.builder() - .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, TEST_FUNCTION))) + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field("c1", Optional.of(BOOLEAN))))) .build(); } @@ -87,11 +115,10 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class TestConnectorTableFunction2 extends AbstractConnectorTableFunction { - private static final String TEST_FUNCTION_2 = "test_function2"; - + private static final String FUNCTION_NAME = "test_function2"; public TestConnectorTableFunction2() { - super(SCHEMA_NAME, TEST_FUNCTION_2, ImmutableList.of(), ONLY_PASS_THROUGH); + super(SCHEMA_NAME, FUNCTION_NAME, ImmutableList.of(), ONLY_PASS_THROUGH); } @Override @@ -104,11 +131,10 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class NullArgumentsTableFunction extends AbstractConnectorTableFunction { - private static final String NULL_ARGUMENTS_FUNCTION = "null_arguments_function"; - + private static final String FUNCTION_NAME = "null_arguments_function"; public NullArgumentsTableFunction() { - super(SCHEMA_NAME, NULL_ARGUMENTS_FUNCTION, null, ONLY_PASS_THROUGH); + super(SCHEMA_NAME, FUNCTION_NAME, null, ONLY_PASS_THROUGH); } @Override @@ -121,12 +147,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class DuplicateArgumentsTableFunction extends AbstractConnectorTableFunction { - private static final String DUPLICATE_ARGUMENTS_FUNCTION = "duplicate_arguments_function"; + private static final String FUNCTION_NAME = "duplicate_arguments_function"; public DuplicateArgumentsTableFunction() { super( SCHEMA_NAME, - DUPLICATE_ARGUMENTS_FUNCTION, + FUNCTION_NAME, ImmutableList.of( ScalarArgumentSpecification.builder().name("a").type(INTEGER).build(), ScalarArgumentSpecification.builder().name("a").type(INTEGER).build()), @@ -143,12 +169,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class MultipleRSTableFunction extends AbstractConnectorTableFunction { - private static final String MULTIPLE_SOURCES_FUNCTION = "multiple_sources_function"; + private static final String FUNCTION_NAME = "multiple_sources_function"; public MultipleRSTableFunction() { super( SCHEMA_NAME, - MULTIPLE_SOURCES_FUNCTION, + FUNCTION_NAME, ImmutableList.of(TableArgumentSpecification.builder().name("t").rowSemantics().build(), TableArgumentSpecification.builder().name("t2").rowSemantics().build()), ONLY_PASS_THROUGH); @@ -172,7 +198,6 @@ public static class SimpleTableFunction { private static final String FUNCTION_NAME = "simple_table_function"; private static final String TABLE_NAME = "simple_table"; - public SimpleTableFunction() { super( @@ -227,11 +252,12 @@ public TestTVFConnectorTableHandle getTableHandle() public static class TwoScalarArgumentsFunction extends AbstractConnectorTableFunction { + private static final String FUNCTION_NAME = "two_scalar_arguments_function"; public TwoScalarArgumentsFunction() { super( SCHEMA_NAME, - "two_arguments_function", + FUNCTION_NAME, ImmutableList.of( ScalarArgumentSpecification.builder() .name("TEXT") @@ -256,7 +282,6 @@ public static class TableArgumentFunction extends AbstractConnectorTableFunction { public static final String FUNCTION_NAME = "table_argument_function"; - public TableArgumentFunction() { super( @@ -284,11 +309,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class DescriptorArgumentFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "descriptor_argument_function"; public DescriptorArgumentFunction() { super( SCHEMA_NAME, - "descriptor_argument_function", + FUNCTION_NAME, ImmutableList.of( DescriptorArgumentSpecification.builder() .name("SCHEMA") @@ -327,11 +353,16 @@ public TestTVFConnectorTableHandle getTableHandle() public static class TestingTableFunctionHandle implements ConnectorTableFunctionHandle { + private final TestTVFConnectorTableHandle tableHandle; private final SchemaFunctionName schemaFunctionName; @JsonCreator public TestingTableFunctionHandle(@JsonProperty("schemaFunctionName") SchemaFunctionName schemaFunctionName) { + this.tableHandle = new TestTVFConnectorTableHandle( + new SchemaTableName(SCHEMA_NAME, TABLE_NAME), + Optional.of(ImmutableList.of(new TestTVFConnectorColumnHandle(COLUMN_NAME, BOOLEAN))), + TupleDomain.all()); this.schemaFunctionName = requireNonNull(schemaFunctionName, "schemaFunctionName is null"); } @@ -340,16 +371,22 @@ public SchemaFunctionName getSchemaFunctionName() { return schemaFunctionName; } + + public TestTVFConnectorTableHandle getTableHandle() + { + return tableHandle; + } } public static class TableArgumentRowSemanticsFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "table_argument_row_semantics_function"; public TableArgumentRowSemanticsFunction() { super( SCHEMA_NAME, - "table_argument_row_semantics_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") @@ -372,17 +409,20 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class TwoTableArgumentsFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "two_table_arguments_function"; public TwoTableArgumentsFunction() { super( SCHEMA_NAME, - "two_table_arguments_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT1") + .keepWhenEmpty() .build(), TableArgumentSpecification.builder() .name("INPUT2") + .keepWhenEmpty() .build()), GENERIC_TABLE); } @@ -402,11 +442,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class OnlyPassThroughFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "only_pass_through_function"; public OnlyPassThroughFunction() { super( SCHEMA_NAME, - "only_pass_through_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") @@ -425,13 +466,14 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class MonomorphicStaticReturnTypeFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "monomorphic_static_return_type_function"; public MonomorphicStaticReturnTypeFunction() { super( SCHEMA_NAME, - "monomorphic_static_return_type_function", + FUNCTION_NAME, ImmutableList.of(), - new DescribedTable(Descriptor.descriptor( + new DescribedTableReturnTypeSpecification(Descriptor.descriptor( ImmutableList.of("a", "b"), ImmutableList.of(BOOLEAN, INTEGER)))); } @@ -448,15 +490,16 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class PolymorphicStaticReturnTypeFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "polymorphic_static_return_type_function"; public PolymorphicStaticReturnTypeFunction() { super( SCHEMA_NAME, - "polymorphic_static_return_type_function", + FUNCTION_NAME, ImmutableList.of(TableArgumentSpecification.builder() .name("INPUT") .build()), - new DescribedTable(Descriptor.descriptor( + new DescribedTableReturnTypeSpecification(Descriptor.descriptor( ImmutableList.of("a", "b"), ImmutableList.of(BOOLEAN, INTEGER)))); } @@ -471,16 +514,18 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class PassThroughFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "pass_through_function"; public PassThroughFunction() { super( SCHEMA_NAME, - "pass_through_function", + FUNCTION_NAME, ImmutableList.of(TableArgumentSpecification.builder() .name("INPUT") .passThroughColumns() + .keepWhenEmpty() .build()), - new DescribedTable(Descriptor.descriptor( + new DescribedTableReturnTypeSpecification(Descriptor.descriptor( ImmutableList.of("x"), ImmutableList.of(BOOLEAN)))); } @@ -495,14 +540,16 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class RequiredColumnsFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "required_columns_function"; public RequiredColumnsFunction() { super( SCHEMA_NAME, - "required_columns_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") + .keepWhenEmpty() .build()), GENERIC_TABLE); } @@ -517,4 +564,915 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact .build(); } } + + public static class DifferentArgumentTypesFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "different_arguments_function"; + public DifferentArgumentTypesFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT_1") + .passThroughColumns() + .keepWhenEmpty() + .build(), + DescriptorArgumentSpecification.builder() + .name("LAYOUT") + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .rowSemantics() + .passThroughColumns() + .build(), + ScalarArgumentSpecification.builder() + .name("ID") + .type(BIGINT) + .build(), + TableArgumentSpecification.builder() + .name("INPUT_3") + .pruneWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT_1", ImmutableList.of(0)) + .requiredColumns("INPUT_2", ImmutableList.of(0)) + .requiredColumns("INPUT_3", ImmutableList.of(0)) + .build(); + } + } + + // for testing execution by operator + + public static class IdentityFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "identity_function"; + public IdentityFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + List inputColumns = ((TableArgument) arguments.get("INPUT")).getRowType().getFields(); + Descriptor returnedType = new Descriptor(inputColumns.stream() + .map(field -> new Descriptor.Field(field.getName().orElse("anonymous_column"), Optional.of(field.getType()))) + .collect(toImmutableList())); + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .returnedType(returnedType) + .requiredColumns("INPUT", IntStream.range(0, inputColumns.size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class IdentityFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return input -> { + if (input == null) { + return FINISHED; + } + Optional inputPage = getOnlyElement(input); + return inputPage.map(TableFunctionProcessorState.Processed::usedInputAndProduced).orElseThrow(NoSuchElementException::new); + }; + } + } + } + + public static class IdentityPassThroughFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "identity_pass_through_function"; + public IdentityPassThroughFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .passThroughColumns() + .keepWhenEmpty() + .build()), + ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", ImmutableList.of(0)) // per spec, function must require at least one column + .build(); + } + + public static class IdentityPassThroughFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new IdentityPassThroughFunctionDataProcessor(); + } + } + + public static class IdentityPassThroughFunctionDataProcessor + implements TableFunctionDataProcessor + { + private long processedPositions; // stateful + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + return FINISHED; + } + + Page page = getOnlyElement(input).orElseThrow(NoSuchElementException::new); + BlockBuilder builder = BIGINT.createBlockBuilder(null, page.getPositionCount()); + for (long index = processedPositions; index < processedPositions + page.getPositionCount(); index++) { + // TODO check for long overflow + builder.writeLong(index); + } + processedPositions = processedPositions + page.getPositionCount(); + return usedInputAndProduced(new Page(builder.build())); + } + } + } + + public static class RepeatFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "repeat"; + public RepeatFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .passThroughColumns() + .keepWhenEmpty() + .build(), + ScalarArgumentSpecification.builder() + .name("N") + .type(INTEGER) + .defaultValue(2L) + .build()), + ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + ScalarArgument count = (ScalarArgument) arguments.get("N"); + requireNonNull(count.getValue(), "count value for function repeat() is null"); + checkArgument((long) count.getValue() > 0, "count value for function repeat() must be positive"); + + return TableFunctionAnalysis.builder() + .handle(new RepeatFunctionHandle((long) count.getValue())) + .requiredColumns("INPUT", ImmutableList.of(0)) // per spec, function must require at least one column + .build(); + } + + public static class RepeatFunctionHandle + implements ConnectorTableFunctionHandle + { + private final long count; + + @JsonCreator + public RepeatFunctionHandle(@JsonProperty("count") long count) + { + this.count = count; + } + + @JsonProperty + public long getCount() + { + return count; + } + } + + public static class RepeatFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new RepeatFunctionDataProcessor(((RepeatFunctionHandle) handle).getCount()); + } + } + + public static class RepeatFunctionDataProcessor + implements TableFunctionDataProcessor + { + private final long count; + + // stateful + private long processedPositions; + private long processedRounds; + private Block indexes; + boolean usedData; + + public RepeatFunctionDataProcessor(long count) + { + this.count = count; + } + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + if (processedRounds < count && indexes != null) { + processedRounds++; + return produced(new Page(indexes)); + } + return FINISHED; + } + + Page page = getOnlyElement(input).orElseThrow(NoSuchElementException::new); + if (processedRounds == 0) { + BlockBuilder builder = BIGINT.createBlockBuilder(null, page.getPositionCount()); + for (long index = processedPositions; index < processedPositions + page.getPositionCount(); index++) { + // TODO check for long overflow + builder.writeLong(index); + } + processedPositions = processedPositions + page.getPositionCount(); + indexes = builder.build(); + usedData = true; + } + else { + usedData = false; + } + processedRounds++; + + Page result = new Page(indexes); + + if (processedRounds == count) { + processedRounds = 0; + indexes = null; + } + + if (usedData) { + return usedInputAndProduced(result); + } + return produced(result); + } + } + } + + public static class EmptyOutputFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "empty_output"; + public EmptyOutputFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + new DescribedTableReturnTypeSpecification(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class EmptyOutputProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new EmptyOutputDataProcessor(); + } + } + + // returns an empty Page (one column, zero rows) for each Page of input + private static class EmptyOutputDataProcessor + implements TableFunctionDataProcessor + { + private static final Page EMPTY_PAGE = new Page(BOOLEAN.createBlockBuilder(null, 0).build()); + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(EMPTY_PAGE); + } + } + } + + public static class EmptyOutputWithPassThroughFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "empty_output_with_pass_through"; + public EmptyOutputWithPassThroughFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .passThroughColumns() + .build()), + new DescribedTableReturnTypeSpecification(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class EmptyOutputWithPassThroughProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new EmptyOutputWithPassThroughDataProcessor(); + } + } + + // returns an empty Page (one proper column and pass-through, zero rows) for each Page of input + private static class EmptyOutputWithPassThroughDataProcessor + implements TableFunctionDataProcessor + { + // one proper channel, and one pass-through index channel + private static final Page EMPTY_PAGE = new Page( + BOOLEAN.createBlockBuilder(null, 0).build(), + BIGINT.createBlockBuilder(null, 0).build()); + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(EMPTY_PAGE); + } + } + } + + public static class TestInputsFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "test_inputs_function"; + public TestInputsFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .rowSemantics() + .name("INPUT_1") + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT_3") + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT_4") + .keepWhenEmpty() + .build()), + new DescribedTableReturnTypeSpecification(new Descriptor(ImmutableList.of(new Descriptor.Field("boolean_result", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT_1", IntStream.range(0, ((TableArgument) arguments.get("INPUT_1")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .requiredColumns("INPUT_2", IntStream.range(0, ((TableArgument) arguments.get("INPUT_2")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .requiredColumns("INPUT_3", IntStream.range(0, ((TableArgument) arguments.get("INPUT_3")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .requiredColumns("INPUT_4", IntStream.range(0, ((TableArgument) arguments.get("INPUT_4")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class TestInputsFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + BlockBuilder resultBuilder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(resultBuilder, true); + + Page result = new Page(resultBuilder.build()); + + return input -> { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(result); + }; + } + } + } + + public static class PassThroughInputFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "pass_through"; + public PassThroughInputFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT_1") + .passThroughColumns() + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .passThroughColumns() + .keepWhenEmpty() + .build()), + new DescribedTableReturnTypeSpecification(new Descriptor(ImmutableList.of( + new Descriptor.Field("input_1_present", Optional.of(BOOLEAN)), + new Descriptor.Field("input_2_present", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT_1", ImmutableList.of(0)) + .requiredColumns("INPUT_2", ImmutableList.of(0)) + .build(); + } + + public static class PassThroughInputProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new PassThroughInputDataProcessor(); + } + } + + private static class PassThroughInputDataProcessor + implements TableFunctionDataProcessor + { + private boolean input1Present; + private boolean input2Present; + private int input1EndIndex; + private int input2EndIndex; + private boolean finished; + + @Override + public TableFunctionProcessorState process(List> input) + { + if (finished) { + return FINISHED; + } + if (input == null) { + finished = true; + + // proper column input_1_present + BlockBuilder input1Builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(input1Builder, input1Present); + + // proper column input_2_present + BlockBuilder input2Builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(input2Builder, input2Present); + + // pass-through index for input_1 + BlockBuilder input1PassThroughBuilder = BIGINT.createBlockBuilder(null, 1); + if (input1Present) { + input1PassThroughBuilder.writeLong(input1EndIndex - 1); + } + else { + input1PassThroughBuilder.appendNull(); + } + + // pass-through index for input_2 + BlockBuilder input2PassThroughBuilder = BIGINT.createBlockBuilder(null, 1); + if (input2Present) { + input2PassThroughBuilder.writeLong(input2EndIndex - 1); + } + else { + input2PassThroughBuilder.appendNull(); + } + + return produced(new Page(input1Builder.build(), input2Builder.build(), input1PassThroughBuilder.build(), input2PassThroughBuilder.build())); + } + input.get(0).ifPresent(page -> { + input1Present = true; + input1EndIndex += page.getPositionCount(); + }); + input.get(1).ifPresent(page -> { + input2Present = true; + input2EndIndex += page.getPositionCount(); + }); + return usedInput(); + } + } + } + + public static class TestInputFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "test_input"; + public TestInputFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + new DescribedTableReturnTypeSpecification(new Descriptor(ImmutableList.of(new Descriptor.Field("got_input", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class TestInputProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new TestInputDataProcessor(); + } + } + + private static class TestInputDataProcessor + implements TableFunctionDataProcessor + { + private boolean processorGotInput; + private boolean finished; + + @Override + public TableFunctionProcessorState process(List> input) + { + if (finished) { + return FINISHED; + } + if (input == null) { + finished = true; + BlockBuilder builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(builder, processorGotInput); + return produced(new Page(builder.build())); + } + processorGotInput = true; + return usedInput(); + } + } + } + + public static class TestSingleInputRowSemanticsFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "test_single_input_function"; + public TestSingleInputRowSemanticsFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .rowSemantics() + .name("INPUT") + .build()), + new DescribedTableReturnTypeSpecification(new Descriptor(ImmutableList.of(new Descriptor.Field("boolean_result", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class TestSingleInputFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + BlockBuilder builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(builder, true); + Page result = new Page(builder.build()); + + return input -> { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(result); + }; + } + } + } + + public static class ConstantFunction + extends AbstractConnectorTableFunction + { + static final String FUNCTION_NAME = "constant"; + public ConstantFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + ScalarArgumentSpecification.builder() + .name("VALUE") + .type(INTEGER) + .build(), + ScalarArgumentSpecification.builder() + .name("N") + .type(INTEGER) + .defaultValue(1L) + .build()), + new DescribedTableReturnTypeSpecification(Descriptor.descriptor( + ImmutableList.of("constant_column"), + ImmutableList.of(INTEGER)))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + ScalarArgument count = (ScalarArgument) arguments.get("N"); + requireNonNull(count.getValue(), "count value for function repeat() is null"); + checkArgument((long) count.getValue() > 0, "count value for function repeat() must be positive"); + + return TableFunctionAnalysis.builder() + .handle(new ConstantFunctionHandle((Long) ((ScalarArgument) arguments.get("VALUE")).getValue(), (long) count.getValue())) + .build(); + } + + public static class ConstantFunctionHandle + implements ConnectorTableFunctionHandle + { + private final Long value; + private final long count; + + @JsonCreator + public ConstantFunctionHandle(@JsonProperty("value") Long value, @JsonProperty("count") long count) + { + this.value = value; + this.count = count; + } + + @JsonProperty + public Long getValue() + { + return value; + } + + @JsonProperty + public long getCount() + { + return count; + } + } + + public static class ConstantFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + return new ConstantFunctionProcessor(((ConstantFunctionHandle) handle).getValue()); + } + } + + public static class ConstantFunctionProcessor + implements TableFunctionSplitProcessor + { + private static final int PAGE_SIZE = 1000; + + private final Long value; + + private long fullPagesCount; + private long processedPages; + private int reminder; + private Block block; + + public ConstantFunctionProcessor(Long value) + { + this.value = value; + } + + @Override + public TableFunctionProcessorState process(ConnectorSplit split) + { + boolean usedData = false; + + if (split != null) { + long count = ((ConstantFunctionSplit) split).getCount(); + this.fullPagesCount = count / PAGE_SIZE; + this.reminder = toIntExact(count % PAGE_SIZE); + if (fullPagesCount > 0) { + BlockBuilder builder = INTEGER.createBlockBuilder(null, PAGE_SIZE); + if (value == null) { + for (int i = 0; i < PAGE_SIZE; i++) { + builder.appendNull(); + } + } + else { + for (int i = 0; i < PAGE_SIZE; i++) { + builder.writeInt(toIntExact(value)); + } + } + this.block = builder.build(); + } + else { + BlockBuilder builder = INTEGER.createBlockBuilder(null, reminder); + if (value == null) { + for (int i = 0; i < reminder; i++) { + builder.appendNull(); + } + } + else { + for (int i = 0; i < reminder; i++) { + builder.writeInt(toIntExact(value)); + } + } + this.block = builder.build(); + } + usedData = true; + } + + if (processedPages < fullPagesCount) { + processedPages++; + Page result = new Page(block); + if (usedData) { + return usedInputAndProduced(result); + } + return produced(result); + } + + if (reminder > 0) { + Page result = new Page(block.getRegion(0, toIntExact(reminder))); + reminder = 0; + if (usedData) { + return usedInputAndProduced(result); + } + return produced(result); + } + + return FINISHED; + } + } + + public static ConnectorSplitSource getConstantFunctionSplitSource(ConstantFunctionHandle handle) + { + long splitSize = ConstantFunctionSplit.DEFAULT_SPLIT_SIZE; + ImmutableList.Builder splits = ImmutableList.builder(); + for (long i = 0; i < handle.getCount() / splitSize; i++) { + splits.add(new ConstantFunctionSplit(splitSize)); + } + long remainingSize = handle.getCount() % splitSize; + if (remainingSize > 0) { + splits.add(new ConstantFunctionSplit(remainingSize)); + } + return new FixedSplitSource(splits.build()); + } + + public static final class ConstantFunctionSplit + implements ConnectorSplit + { + private static final int INSTANCE_SIZE = toIntExact(ClassLayout.parseClass(ConstantFunctionSplit.class).instanceSize()); + public static final int DEFAULT_SPLIT_SIZE = 5500; + + private final long count; + + @JsonCreator + public ConstantFunctionSplit(@JsonProperty("count") long count) + { + this.count = count; + } + + @JsonProperty + public long getCount() + { + return count; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return count; + } + } + } + + public static class EmptySourceFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "empty_source"; + public EmptySourceFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(), + new DescribedTableReturnTypeSpecification(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .build(); + } + + public static class EmptySourceFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + return new EmptySourceFunctionProcessor(); + } + } + + public static class EmptySourceFunctionProcessor + implements TableFunctionSplitProcessor + { + private static final Page EMPTY_PAGE = new Page(BOOLEAN.createBlockBuilder(null, 0).build()); + + @Override + public TableFunctionProcessorState process(ConnectorSplit split) + { + if (split == null) { + return FINISHED; + } + + return usedInputAndProduced(EMPTY_PAGE); + } + } + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestTableFunctionRegistry.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestTableFunctionRegistry.java index e7670a5af9b5d..18d3903879b4f 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestTableFunctionRegistry.java +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestTableFunctionRegistry.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.function.table.TableFunctionMetadata; import com.facebook.presto.spi.security.Identity; import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; @@ -29,7 +30,8 @@ import static com.facebook.presto.connector.tvf.TestingTableFunctions.TestConnectorTableFunction2; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.testing.assertions.Assert.assertEquals; -import static org.testng.Assert.assertFalse; +import static com.facebook.presto.testing.assertions.Assert.assertNotNull; +import static com.facebook.presto.testing.assertions.Assert.assertNull; import static org.testng.Assert.assertTrue; import static org.testng.Assert.expectThrows; @@ -63,19 +65,19 @@ public void testTableFunctionRegistry() assertTrue(ex.getMessage().contains("Table functions already registered for catalog: test"), ex.getMessage()); // Verify table function resolution. - assertTrue(testFunctionRegistry.resolve(SESSION, QualifiedName.of(TEST_FUNCTION)).isPresent()); - assertTrue(testFunctionRegistry.resolve(SESSION, QualifiedName.of(TEST_FUNCTION_2)).isPresent()); - assertFalse(testFunctionRegistry.resolve(SESSION, QualifiedName.of("none")).isPresent()); - assertFalse(testFunctionRegistry.resolve(MISMATCH_SESSION, QualifiedName.of("none")).isPresent()); + assertNotNull(testFunctionRegistry.resolve(SESSION, QualifiedName.of(TEST_FUNCTION))); + assertNotNull(testFunctionRegistry.resolve(SESSION, QualifiedName.of(TEST_FUNCTION_2))); + assertNotNull(testFunctionRegistry.resolve(SESSION, QualifiedName.of("none"))); + assertNotNull(testFunctionRegistry.resolve(MISMATCH_SESSION, QualifiedName.of("none"))); // Verify metadata. - TableFunctionMetadata data = testFunctionRegistry.resolve(SESSION, QualifiedName.of(TEST_FUNCTION)).get(); + TableFunctionMetadata data = testFunctionRegistry.resolve(SESSION, QualifiedName.of(TEST_FUNCTION)); assertEquals(data.getConnectorId(), id); assertTrue(data.getFunction() instanceof TestConnectorTableFunction); // Verify the removal of table functions. testFunctionRegistry.removeTableFunctions(id); - assertFalse(testFunctionRegistry.resolve(SESSION, QualifiedName.of(TEST_FUNCTION)).isPresent()); + assertNull(testFunctionRegistry.resolve(SESSION, QualifiedName.of(TEST_FUNCTION))); // Verify that null arguments table functions cannot be added. ex = expectThrows(NullPointerException.class, () -> testFunctionRegistry.addTableFunctions(id, ImmutableList.of(new NullArgumentsTableFunction()))); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java index 5c28e36be9241..458c86f986203 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java @@ -587,11 +587,6 @@ protected void assertFails(SemanticErrorCode error, String message, @Language("S assertFails(CLIENT_SESSION, error, message, query, false); } - protected void assertFailsExact(SemanticErrorCode error, String message, @Language("SQL") String query) - { - assertFails(CLIENT_SESSION, error, message, query, true); - } - protected void assertFails(Session session, SemanticErrorCode error, @Language("SQL") String query) { assertFails(session, error, Optional.empty(), query); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java index a74d59c1e2d03..c3236d3e6627a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java @@ -33,6 +33,7 @@ import org.testng.annotations.Test; import java.util.List; +import java.util.regex.Pattern; import static com.facebook.presto.metadata.SessionPropertyManager.createTestingSessionPropertyManager; import static com.facebook.presto.spi.StandardWarningCode.PERFORMANCE_WARNING; @@ -1979,59 +1980,59 @@ public void testTableFunctionNotFound() @Test public void testTableFunctionArguments() { - assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, "line 1:51: Too many arguments. Expected at most 2 arguments, got 3 arguments", "SELECT * FROM TABLE(system.two_arguments_function(1, 2, 3))"); + assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, "line 1:58: Too many arguments. Expected at most 2 arguments, got 3 arguments", "SELECT * FROM TABLE(system.two_scalar_arguments_function(1, 2, 3))"); - analyze("SELECT * FROM TABLE(system.two_arguments_function('foo'))"); - analyze("SELECT * FROM TABLE(system.two_arguments_function(text => 'foo'))"); - analyze("SELECT * FROM TABLE(system.two_arguments_function('foo', 1))"); - analyze("SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', number => 1))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('foo'))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo'))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('foo', 1))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', number => 1))"); assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, - "line 1:51: All arguments must be passed by name or all must be passed positionally", - "SELECT * FROM TABLE(system.two_arguments_function('foo', number => 1))"); + "line 1:58: All arguments must be passed by name or all must be passed positionally", + "SELECT * FROM TABLE(system.two_scalar_arguments_function('foo', number => 1))"); assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, - "line 1:51: All arguments must be passed by name or all must be passed positionally", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', 1))"); + "line 1:58: All arguments must be passed by name or all must be passed positionally", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', 1))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:66: Duplicate argument name: TEXT", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', text => 'bar'))"); + "line 1:73: Duplicate argument name: TEXT", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', text => 'bar'))"); // argument names are resolved in the canonical form assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:66: Duplicate argument name: TEXT", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', TeXt => 'bar'))"); + "line 1:73: Duplicate argument name: TEXT", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', TeXt => 'bar'))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:66: Unexpected argument name: BAR", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'foo', bar => 'bar'))"); + "line 1:73: Unexpected argument name: BAR", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', bar => 'bar'))"); assertFails(TABLE_FUNCTION_MISSING_ARGUMENT, - "line 1:51: Missing argument: TEXT", - "SELECT * FROM TABLE(system.two_arguments_function(number => 1))"); + "line 1:58: Missing argument: TEXT", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(number => 1))"); } @Test public void testScalarArgument() { - analyze("SELECT * FROM TABLE(system.two_arguments_function('foo', 1))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('foo', 1))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:64: Invalid argument NUMBER. Expected expression, got descriptor", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => DESCRIPTOR(x integer, y boolean)))"); + "line 1:71: Invalid argument NUMBER. Expected expression, got descriptor", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => DESCRIPTOR(x integer, y boolean)))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:64: 'descriptor' function is not allowed as a table function argument", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => DESCRIPTOR(1 + 2)))"); + "line 1:71: 'descriptor' function is not allowed as a table function argument", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => DESCRIPTOR(1 + 2)))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:64: Invalid argument NUMBER. Expected expression, got table", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => TABLE(t1)))"); + "line 1:71: Invalid argument NUMBER. Expected expression, got table", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => TABLE(t1)))"); assertFails(EXPRESSION_NOT_CONSTANT, - "line 1:74: Constant expression cannot contain a subquery", - "SELECT * FROM TABLE(system.two_arguments_function(text => 'a', number => (SELECT 1)))"); + "line 1:81: Constant expression cannot contain a subquery", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => (SELECT 1)))"); } @Test @@ -2127,8 +2128,8 @@ public void testDescriptorArgument() { analyze("SELECT * FROM TABLE(system.descriptor_argument_function(schema => DESCRIPTOR(x integer, y boolean)))"); - assertFailsExact(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, - "line 1:57: Invalid descriptor argument SCHEMA. Descriptors should be formatted as 'DESCRIPTOR(name [type], ...)'", + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + Pattern.quote("line 1:57: Invalid descriptor argument SCHEMA. Descriptors should be formatted as 'DESCRIPTOR(name [type], ...)'"), "SELECT * FROM TABLE(system.descriptor_argument_function(schema => DESCRIPTOR(1 + 2)))"); assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, @@ -2243,10 +2244,10 @@ public void testNullArguments() // the default value for the argument schema is null analyze("SELECT * FROM TABLE(system.descriptor_argument_function())"); - analyze("SELECT * FROM TABLE(system.two_arguments_function(null, null))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function(null, null))"); // the default value for the second argument is null - analyze("SELECT * FROM TABLE(system.two_arguments_function('a'))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('a'))"); } @Test @@ -2258,8 +2259,8 @@ public void testTableFunctionInvocationContext() "SELECT * FROM TABLE(system.only_pass_through_function(TABLE(t1))) f(x)"); // per SQL standard, relation alias is required for table function with GENERIC TABLE return type. We don't require it. - analyze("SELECT * FROM TABLE(system.two_arguments_function('a', 1)) f(x)"); - analyze("SELECT * FROM TABLE(system.two_arguments_function('a', 1))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('a', 1)) f(x)"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('a', 1))"); // per SQL standard, relation alias is required for table function with statically declared return type, only if the function is polymorphic. // We don't require aliasing polymorphic functions. @@ -2276,7 +2277,7 @@ public void testTableFunctionInvocationContext() // aliased + sampled assertFails(TABLE_FUNCTION_INVALID_TABLE_FUNCTION_INVOCATION, "line 1:15: Cannot apply sample to polymorphic table function invocation", - "SELECT * FROM TABLE(system.two_arguments_function('a', 1)) f(x) TABLESAMPLE BERNOULLI (10)"); + "SELECT * FROM TABLE(system.two_scalar_arguments_function('a', 1)) f(x) TABLESAMPLE BERNOULLI (10)"); } @Test @@ -2294,19 +2295,19 @@ public void testTableFunctionAliasing() analyze("SELECT * FROM TABLE(system.table_argument_function(TABLE(t1) t2)) T1(x)"); // the original returned relation type is ("column" : BOOLEAN) - analyze("SELECT column FROM TABLE(system.two_arguments_function('a', 1)) table_alias"); + analyze("SELECT column FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias"); - analyze("SELECT column_alias FROM TABLE(system.two_arguments_function('a', 1)) table_alias(column_alias)"); + analyze("SELECT column_alias FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(column_alias)"); - analyze("SELECT table_alias.column_alias FROM TABLE(system.two_arguments_function('a', 1)) table_alias(column_alias)"); + analyze("SELECT table_alias.column_alias FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(column_alias)"); assertFails(MISSING_ATTRIBUTE, "line 1:8: Column 'column' cannot be resolved", - "SELECT column FROM TABLE(system.two_arguments_function('a', 1)) table_alias(column_alias)"); + "SELECT column FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(column_alias)"); assertFails(MISMATCHED_COLUMN_ALIASES, "line 1:20: Column alias list has 3 entries but table function has 1 proper columns", - "SELECT column FROM TABLE(system.two_arguments_function('a', 1)) table_alias(col1, col2, col3)"); + "SELECT column FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(col1, col2, col3)"); // the original returned relation type is ("a" : BOOLEAN, "b" : INTEGER) analyze("SELECT column_alias_1, column_alias_2 FROM TABLE(system.monomorphic_static_return_type_function()) table_alias(column_alias_1, column_alias_2)"); @@ -2348,8 +2349,10 @@ public void testTableFunctionRequiredColumns() "Invalid index: 1 of required column from table argument INPUT", "SELECT * FROM TABLE(system.required_columns_function(input => TABLE(SELECT 1)))"); - // table s1.t5 has two columns. The second column is hidden. Table function can require a hidden column. - analyze("SELECT * FROM TABLE(system.required_columns_function(input => TABLE(s1.t5)))"); + // table s1.t5 has two columns. The second column is hidden. Table function cannot require a hidden column. + assertFails(TABLE_FUNCTION_IMPLEMENTATION_ERROR, + "Invalid index: 1 of required column from table argument INPUT", + "SELECT * FROM TABLE(system.required_columns_function(input => TABLE(s1.t5)))"); } @Test diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index a8a69731f4361..46f2f031e858b 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -1803,6 +1803,16 @@ public void testOffsetWithLimit() .withAlias("row_num", new RowNumberSymbolMatcher())))))); } + @Test + public void testRewriteExcludeColumnsFunctionToProjection() + { + assertPlan("SELECT *\n" + + "FROM TABLE(system.builtin.exclude_columns(\n" + + " INPUT => TABLE(orders),\n" + + " COLUMNS => DESCRIPTOR(comment)))\n", + output(tableScan("orders"))); + } + private Session noJoinReordering() { return Session.builder(this.getQueryRunner().getDefaultSession()) diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java new file mode 100644 index 0000000000000..6d236432e1d15 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java @@ -0,0 +1,272 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.connector.tvf.TestTVFConnectorFactory; +import com.facebook.presto.connector.tvf.TestTVFConnectorPlugin; +import com.facebook.presto.connector.tvf.TestingTableFunctions; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DescriptorArgumentFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DifferentArgumentTypesFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TestingTableFunctionHandle; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoScalarArgumentsFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoTableArgumentsFunction; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.Descriptor.Field; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.assertions.RowNumberSymbolMatcher; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.sql.tree.LongLiteral; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.sql.Optimizer.PlanStage.CREATED; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.rowNumber; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictOutput; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunction; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.descriptorArgument; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.nullDescriptor; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.TableArgumentValue.Builder.tableArgument; + +public class TestTableFunctionInvocation + extends BasePlanTest +{ + private static final String TESTING_CATALOG = "test"; + + @BeforeClass + public final void setup() + { + getQueryRunner().installPlugin(new TestTVFConnectorPlugin(TestTVFConnectorFactory.builder() + .withTableFunctions(ImmutableSet.of( + new DifferentArgumentTypesFunction(), + new TwoScalarArgumentsFunction(), + new TwoTableArgumentsFunction(), + new DescriptorArgumentFunction(), + new TestingTableFunctions.PassThroughFunction())) + .withApplyTableFunction((session, handle) -> { + if (handle instanceof TestingTableFunctionHandle) { + TestingTableFunctionHandle functionHandle = (TestingTableFunctionHandle) handle; + return Optional.of(new TableFunctionApplicationResult<>(functionHandle.getTableHandle(), functionHandle.getTableHandle().getColumns().orElseThrow(() -> new IllegalStateException("Missing columns")))); + } + throw new IllegalStateException("Unsupported table function handle: " + handle.getClass().getSimpleName()); + }) + .build())); + getQueryRunner().createCatalog(TESTING_CATALOG, "testTVF", ImmutableMap.of()); + } + + @Test + public void testTableFunctionInitialPlan() + { + assertPlan( + "SELECT * FROM TABLE(test.system.different_arguments_function(" + + "INPUT_1 => TABLE(SELECT 'a') t1(c1) PARTITION BY c1 ORDER BY c1," + + "INPUT_3 => TABLE(SELECT 'b') t3(c3) PARTITION BY c3," + + "INPUT_2 => TABLE(VALUES 1) t2(c2)," + + "ID => BIGINT '2001'," + + "LAYOUT => DESCRIPTOR (x boolean, y bigint)" + + "COPARTITION (t1, t3))) t", + CREATED, + anyTree(tableFunction(builder -> builder + .name("different_arguments_function") + .addTableArgument( + "INPUT_1", + tableArgument(0) + .specification(specification(ImmutableList.of("c1"), ImmutableList.of("c1"), ImmutableMap.of("c1", ASC_NULLS_LAST))) + .passThroughVariables(ImmutableSet.of("c1")) + .passThroughColumns()) + .addTableArgument( + "INPUT_3", + tableArgument(2) + .specification(specification(ImmutableList.of("c3"), ImmutableList.of(), ImmutableMap.of())) + .pruneWhenEmpty() + .passThroughVariables(ImmutableSet.of("c3"))) + .addTableArgument( + "INPUT_2", + tableArgument(1) + .rowSemantics() + .passThroughVariables(ImmutableSet.of("c2")) + .passThroughColumns()) + .addScalarArgument("ID", 2001L) + .addDescriptorArgument( + "LAYOUT", + descriptorArgument(new Descriptor(ImmutableList.of( + new Field("X", Optional.of(BOOLEAN)), + new Field("Y", Optional.of(BIGINT)))))) + .addCopartitioning(ImmutableList.of("INPUT_1", "INPUT_3")) + .properOutputs(ImmutableList.of("OUTPUT")), + anyTree(project(ImmutableMap.of("c1", expression("'a'")), values(1))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new LongLiteral("1"))))), + anyTree(project(ImmutableMap.of("c3", expression("'b'")), values(1)))))); + } + + @Test + public void testTableFunctionInitialPlanWithCoercionForCopartitioning() + { + assertPlan("SELECT * FROM TABLE(test.system.two_table_arguments_function(" + + "INPUT1 => TABLE(VALUES SMALLINT '1') t1(c1) PARTITION BY c1," + + "INPUT2 => TABLE(VALUES INTEGER '2') t2(c2) PARTITION BY c2 " + + "COPARTITION (t1, t2))) t", + CREATED, + anyTree(tableFunction(builder -> builder + .name("two_table_arguments_function") + .addTableArgument( + "INPUT1", + tableArgument(0) + .specification(specification(ImmutableList.of("c1_coerced"), ImmutableList.of(), ImmutableMap.of())) + .passThroughVariables(ImmutableSet.of("c1"))) + .addTableArgument( + "INPUT2", + tableArgument(1) + .specification(specification(ImmutableList.of("c2"), ImmutableList.of(), ImmutableMap.of())) + .passThroughVariables(ImmutableSet.of("c2"))) + .addCopartitioning(ImmutableList.of("INPUT1", "INPUT2")) + .properOutputs(ImmutableList.of("COLUMN")), + project(ImmutableMap.of("c1_coerced", expression("CAST(c1 AS INTEGER)")), + anyTree(values(ImmutableList.of("c1"), ImmutableList.of(ImmutableList.of(new LongLiteral("1")))))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new LongLiteral("2")))))))); + } + + @Test + public void testNullScalarArgument() + { + // the argument NUMBER has null default value + assertPlan( + " SELECT * FROM TABLE(test.system.two_scalar_arguments_function(TEXT => null))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("two_scalar_arguments_function") + .addScalarArgument("TEXT", null) + .addScalarArgument("NUMBER", null) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } + + @Test + public void testNullDescriptorArgument() + { + assertPlan( + " SELECT * FROM TABLE(test.system.descriptor_argument_function(SCHEMA => CAST(null AS DESCRIPTOR)))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + + // the argument SCHEMA has null default value + assertPlan( + " SELECT * FROM TABLE(test.system.descriptor_argument_function())", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } + + @Test + public void testPruneTableFunctionColumns() + { + // all table function outputs are referenced with SELECT *, no pruning + assertPlan("SELECT * FROM TABLE(test.system.pass_through_function(input => TABLE(SELECT 1, true) t(a, b)))", + strictOutput( + ImmutableList.of("x", "a", "b"), + tableFunctionProcessor( + builder -> builder + .name("pass_through_function") + .properOutputs(ImmutableList.of("x")) + .passThroughSymbols( + ImmutableList.of(ImmutableList.of("a", "b"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("a"))) + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())), + project(ImmutableMap.of("a", expression("INTEGER'1'"), "b", expression("BOOLEAN'true'")), values(1))))); + + // no table function outputs are referenced. All pass-through symbols are pruned from the TableFunctionProcessorNode. The unused symbol "b" is pruned from the source values node. + assertPlan("SELECT 'constant' c FROM TABLE(test.system.pass_through_function(input => TABLE(SELECT 1, true) t(a, b)))", + strictOutput( + ImmutableList.of("c"), + strictProject( + ImmutableMap.of("c", expression("VARCHAR'constant'")), + tableFunctionProcessor( + builder -> builder + .name("pass_through_function") + .properOutputs(ImmutableList.of("x")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("a"))) + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())), + project(ImmutableMap.of("a", expression("INTEGER'1'")), values(1)))))); + } + + @Test + public void testRemoveRedundantTableFunction() + { + assertPlan("SELECT * FROM TABLE(test.system.pass_through_function(input => TABLE(SELECT 1, true WHERE false) t(a, b) PRUNE WHEN EMPTY))", + output(values(ImmutableList.of("x", "a", "b")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) PRUNE WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false) t2(c, d) KEEP WHEN EMPTY))\n", + output(values(ImmutableList.of("column")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) PRUNE WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false WHERE false) t2(c, d) PRUNE WHEN EMPTY))\n", + output(values(ImmutableList.of("column")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) PRUNE WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false WHERE false) t2(c, d) KEEP WHEN EMPTY))\n", + output(values(ImmutableList.of("column")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) KEEP WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false WHERE false) t2(c, d) KEEP WHEN EMPTY))\n", + output( + node(TableFunctionProcessorNode.class, + values(ImmutableList.of("a", "marker_1", "c", "marker_2", "row_number"))))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) KEEP WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false) t2(c, d) PRUNE WHEN EMPTY))\n", + output( + node(TableFunctionProcessorNode.class, + project( + project( + rowNumber( + builder -> builder.partitionBy(ImmutableList.of()), + project( + ImmutableMap.of("c", expression("INTEGER'2'")), + values(1)) + ).withAlias("input_2_row_number", new RowNumberSymbolMatcher())))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index d882fe7e54a5f..7b41ae0b02bf6 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -658,6 +658,11 @@ public static PlanMatchPattern values(Map aliasToIndex) return values(aliasToIndex, Optional.empty(), Optional.empty()); } + public static PlanMatchPattern values(int rowCount) + { + return values(ImmutableList.of(), nCopies(rowCount, ImmutableList.of())); + } + public static PlanMatchPattern values(String... aliases) { return values(ImmutableList.copyOf(aliases)); @@ -713,6 +718,27 @@ public static PlanMatchPattern remoteSource(List sourceFragmentI return node(RemoteSourceNode.class).with(new RemoteSourceMatcher(sourceFragmentIds, outputSymbolAliases)); } + public static PlanMatchPattern tableFunction(Consumer handler, PlanMatchPattern... sources) + { + TableFunctionMatcher.Builder builder = new TableFunctionMatcher.Builder(sources); + handler.accept(builder); + return builder.build(); + } + + public static PlanMatchPattern tableFunctionProcessor(Consumer handler, PlanMatchPattern source) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(source); + handler.accept(builder); + return builder.build(); + } + + public static PlanMatchPattern tableFunctionProcessor(Consumer handler) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(); + handler.accept(builder); + return builder.build(); + } + public PlanMatchPattern(List sourcePatterns) { requireNonNull(sourcePatterns, "sourcePatterns are null"); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java new file mode 100644 index 0000000000000..c14b68b443867 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java @@ -0,0 +1,412 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.ScalarArgument; +import com.facebook.presto.spi.function.table.TableArgument; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static com.facebook.presto.sql.planner.QueryPlanner.toSymbolReferences; +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.SpecificationProvider.matchSpecification; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; + +public class TableFunctionMatcher + implements Matcher +{ + private final String name; + private final Map arguments; + private final List properOutputs; + private final List> copartitioningLists; + + private TableFunctionMatcher( + String name, + Map arguments, + List properOutputs, + List> copartitioningLists) + { + this.name = requireNonNull(name, "name is null"); + this.arguments = ImmutableMap.copyOf(requireNonNull(arguments, "arguments is null")); + this.properOutputs = ImmutableList.copyOf(requireNonNull(properOutputs, "properOutputs is null")); + requireNonNull(copartitioningLists, "copartitioningLists is null"); + this.copartitioningLists = copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableFunctionNode tableFunctionNode = (TableFunctionNode) node; + + if (!name.equals(tableFunctionNode.getName())) { + return NO_MATCH; + } + + if (arguments.size() != tableFunctionNode.getArguments().size()) { + return NO_MATCH; + } + for (Map.Entry entry : arguments.entrySet()) { + String name = entry.getKey(); + Argument actual = tableFunctionNode.getArguments().get(name); + if (actual == null) { + return NO_MATCH; + } + ArgumentValue expected = entry.getValue(); + switch (expected.getType()) { + case DescriptorArgumentValue.type: + DescriptorArgumentValue expectedDescriptor = (DescriptorArgumentValue) expected; + if (!(actual instanceof DescriptorArgument) || !expectedDescriptor.getDescriptor().equals(((DescriptorArgument) actual).getDescriptor())) { + return NO_MATCH; + } + break; + case ScalarArgumentValue.type: + ScalarArgumentValue expectedScalar = (ScalarArgumentValue) expected; + if (!(actual instanceof ScalarArgument) || !Objects.equals(expectedScalar.getValue(), ((ScalarArgument) actual).getValue())) { + return NO_MATCH; + } + break; + default: + if (!(actual instanceof TableArgument) || getMatchResult(symbolAliases, (TableArgumentValue) expected, tableFunctionNode, name).equals(NO_MATCH)) { + return NO_MATCH; + } + } + } + + if (!ImmutableSet.copyOf(copartitioningLists).equals(ImmutableSet.copyOf(tableFunctionNode.getCopartitioningLists()))) { + return NO_MATCH; + } + + ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + private MatchResult getMatchResult(SymbolAliases symbolAliases, TableArgumentValue expected, TableFunctionNode tableFunctionNode, String name) + { + TableArgumentValue expectedTableArgument = expected; + TableArgumentProperties argumentProperties = tableFunctionNode.getTableArgumentProperties().get(expectedTableArgument.sourceIndex()); + if (!name.equals(argumentProperties.getArgumentName())) { + return NO_MATCH; + } + if (expectedTableArgument.rowSemantics() != argumentProperties.isRowSemantics() || + expectedTableArgument.pruneWhenEmpty() != argumentProperties.isPruneWhenEmpty() || + expectedTableArgument.passThroughColumns() != argumentProperties.getPassThroughSpecification().isDeclaredAsPassThrough()) { + return NO_MATCH; + } + + if (expectedTableArgument.specification().isPresent() != argumentProperties.getSpecification().isPresent()) { + return NO_MATCH; + } + if (!expectedTableArgument.specification() + .map(expectedSpecification -> matchSpecification(argumentProperties.getSpecification().get(), expectedSpecification.getExpectedValue(symbolAliases))) + .orElse(true)) { + return NO_MATCH; + } + Set expectedPassThrough = expectedTableArgument.passThroughVariables().stream() + .map(symbolAliases::get) + .collect(toImmutableSet()); + Set actualPassThrough = toSymbolReferences( + argumentProperties.getPassThroughSpecification().getColumns().stream() + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .collect(Collectors.toList())) + .stream() + .map(SymbolReference.class::cast) + .collect(Collectors.toSet()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } + return match(symbolAliases); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("arguments", arguments) + .add("properOutputs", properOutputs) + .add("copartitioningLists", copartitioningLists) + .toString(); + } + + public static class Builder + { + private final PlanMatchPattern[] sources; + private String name; + private final ImmutableMap.Builder arguments = ImmutableMap.builder(); + private List properOutputs = ImmutableList.of(); + private final ImmutableList.Builder> copartitioningLists = ImmutableList.builder(); + + Builder(PlanMatchPattern... sources) + { + this.sources = Arrays.copyOf(sources, sources.length); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder addDescriptorArgument(String name, DescriptorArgumentValue descriptor) + { + this.arguments.put(name, descriptor); + return this; + } + + public Builder addScalarArgument(String name, Object value) + { + this.arguments.put(name, new ScalarArgumentValue(value)); + return this; + } + + public Builder addTableArgument(String name, TableArgumentValue.Builder tableArgument) + { + this.arguments.put(name, tableArgument.build()); + return this; + } + + public Builder properOutputs(List properOutputs) + { + this.properOutputs = properOutputs; + return this; + } + + public Builder addCopartitioning(List copartitioning) + { + this.copartitioningLists.add(copartitioning); + return this; + } + + public PlanMatchPattern build() + { + return node(TableFunctionNode.class, sources) + .with(new TableFunctionMatcher(name, arguments.buildOrThrow(), properOutputs, copartitioningLists.build())); + } + } + + interface ArgumentValue + { + String getType(); + } + + public static class DescriptorArgumentValue + implements ArgumentValue + { + private final Optional descriptor; + public static final String type = "Descriptor"; + + public DescriptorArgumentValue(Optional descriptor) + { + this.descriptor = requireNonNull(descriptor, "descriptor is null"); + } + + public static DescriptorArgumentValue descriptorArgument(Descriptor descriptor) + { + return new DescriptorArgumentValue(Optional.of(requireNonNull(descriptor, "descriptor is null"))); + } + + public static DescriptorArgumentValue nullDescriptor() + { + return new DescriptorArgumentValue(Optional.empty()); + } + + public Optional getDescriptor() + { + return descriptor; + } + + @Override + public String getType() + { + return type; + } + } + + public static class ScalarArgumentValue + implements ArgumentValue + { + private final Object value; + public static final String type = "Scalar"; + + public ScalarArgumentValue(Object value) + { + this.value = value; + } + + public Object getValue() + { + return value; + } + + @Override + public String getType() + { + return type; + } + } + + public static class TableArgumentValue + implements ArgumentValue + { + private final int sourceIndex; + private final boolean rowSemantics; + private final boolean pruneWhenEmpty; + private final boolean passThroughColumns; + private final Optional> specification; + private final Set passThroughVariables; + public static final String type = "Table"; + + public TableArgumentValue(int sourceIndex, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns, Optional> specification, Set passThroughVariables) + { + this.sourceIndex = sourceIndex; + this.rowSemantics = rowSemantics; + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughColumns = passThroughColumns; + this.specification = requireNonNull(specification, "specification is null"); + this.passThroughVariables = ImmutableSet.copyOf(passThroughVariables); + } + + public int sourceIndex() + { + return sourceIndex; + } + + public boolean rowSemantics() + { + return rowSemantics; + } + + public boolean pruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public boolean passThroughColumns() + { + return passThroughColumns; + } + + public Set passThroughVariables() + { + return passThroughVariables; + } + + public Optional> specification() + { + return specification; + } + + @Override + public String getType() + { + return type; + } + + public static class Builder + { + private final int sourceIndex; + private boolean rowSemantics; + private boolean pruneWhenEmpty; + private boolean passThroughColumns; + private Optional> specification = Optional.empty(); + private Set passThroughVariables = ImmutableSet.of(); + + private Builder(int sourceIndex) + { + this.sourceIndex = sourceIndex; + } + + public static Builder tableArgument(int sourceIndex) + { + return new Builder(sourceIndex); + } + + public Builder rowSemantics() + { + this.rowSemantics = true; + this.pruneWhenEmpty = true; + return this; + } + + public Builder pruneWhenEmpty() + { + this.pruneWhenEmpty = true; + return this; + } + + public Builder passThroughColumns() + { + this.passThroughColumns = true; + return this; + } + + public Builder specification(ExpectedValueProvider specification) + { + this.specification = Optional.of(specification); + return this; + } + + public Builder passThroughVariables(Set variables) + { + this.passThroughVariables = variables; + return this; + } + + private TableArgumentValue build() + { + return new TableArgumentValue(sourceIndex, rowSemantics, pruneWhenEmpty, passThroughColumns, specification, passThroughVariables); + } + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java new file mode 100644 index 0000000000000..4891c3eb021dd --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java @@ -0,0 +1,239 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.sql.planner.QueryPlanner; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.QueryPlanner.toSymbolReference; +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.SpecificationProvider.matchSpecification; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; + +public class TableFunctionProcessorMatcher + implements Matcher +{ + private final String name; + private final List properOutputs; + private final List> passThroughSymbols; + private final List> requiredSymbols; + private final Optional> markerSymbols; + private final Optional> specification; + private final Optional hashSymbol; + + private TableFunctionProcessorMatcher( + String name, + List properOutputs, + List> passThroughSymbols, + List> requiredSymbols, + Optional> markerSymbols, + Optional> specification, + Optional hashSymbol) + { + this.name = requireNonNull(name, "name is null"); + this.properOutputs = ImmutableList.copyOf(properOutputs); + this.passThroughSymbols = passThroughSymbols.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.requiredSymbols = requiredSymbols.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerSymbols = markerSymbols.map(ImmutableMap::copyOf); + this.specification = requireNonNull(specification, "specification is null"); + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionProcessorNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableFunctionProcessorNode tableFunctionProcessorNode = (TableFunctionProcessorNode) node; + + if (!name.equals(tableFunctionProcessorNode.getName())) { + return NO_MATCH; + } + + if (properOutputs.size() != tableFunctionProcessorNode.getProperOutputs().size()) { + return NO_MATCH; + } + + List> expectedPassThrough = passThroughSymbols.stream() + .map(list -> list.stream() + .map(symbolAliases::get) + .collect(toImmutableList())) + .collect(toImmutableList()); + List> actualPassThrough = tableFunctionProcessorNode.getPassThroughSpecifications().stream() + .map(PassThroughSpecification::getColumns) + .map(list -> list.stream() + .map(PassThroughColumn::getOutputVariables) + .map(QueryPlanner::toSymbolReference) + .collect(toImmutableList())) + .collect(toImmutableList()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } + + if (markerSymbols.isPresent() != tableFunctionProcessorNode.getMarkerVariables().isPresent()) { + return NO_MATCH; + } + if (markerSymbols.isPresent()) { + Map expectedMapping = markerSymbols.get().entrySet().stream() + .collect(toImmutableMap(entry -> symbolAliases.get(entry.getKey()), entry -> symbolAliases.get(entry.getValue()))); + Map actualMapping = tableFunctionProcessorNode.getMarkerVariables().get().entrySet().stream() + .collect(toImmutableMap(entry -> toSymbolReference(entry.getKey()), entry -> toSymbolReference(entry.getValue()))); + if (!expectedMapping.equals(actualMapping)) { + return NO_MATCH; + } + } + + if (specification.isPresent() != tableFunctionProcessorNode.getSpecification().isPresent()) { + return NO_MATCH; + } + if (specification.isPresent()) { + if (!matchSpecification(specification.get().getExpectedValue(symbolAliases), tableFunctionProcessorNode.getSpecification().orElseThrow(NoSuchElementException::new))) { + return NO_MATCH; + } + } + if (hashSymbol.isPresent()) { + if (!hashSymbol.map(symbolAliases::get).equals(tableFunctionProcessorNode.getHashSymbol().map(QueryPlanner::toSymbolReference))) { + return NO_MATCH; + } + } + + ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + for (int i = 0; i < properOutputs.size(); i++) { + properOutputsMapping.put(properOutputs.get(i), toSymbolReference(tableFunctionProcessorNode.getProperOutputs().get(i))); + } + + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("properOutputs", properOutputs) + .add("passThroughSymbols", passThroughSymbols) + .add("requiredSymbols", requiredSymbols) + .add("markerSymbols", markerSymbols) + .add("specification", specification) + .add("hashSymbol", hashSymbol) + .toString(); + } + + public static class Builder + { + private final Optional source; + private String name; + private List properOutputs = ImmutableList.of(); + private List> passThroughSymbols = ImmutableList.of(); + private List> requiredSymbols = ImmutableList.of(); + private Optional> markerSymbols = Optional.empty(); + private Optional> specification = Optional.empty(); + private Optional hashSymbol = Optional.empty(); + + public Builder() + { + this.source = Optional.empty(); + } + + public Builder(PlanMatchPattern source) + { + this.source = Optional.of(source); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder properOutputs(List properOutputs) + { + this.properOutputs = properOutputs; + return this; + } + + public Builder passThroughSymbols(List> passThroughSymbols) + { + this.passThroughSymbols = passThroughSymbols; + return this; + } + + public Builder requiredSymbols(List> requiredSymbols) + { + this.requiredSymbols = requiredSymbols; + return this; + } + + public Builder markerSymbols(Map markerSymbols) + { + this.markerSymbols = Optional.of(markerSymbols); + return this; + } + + public Builder specification(ExpectedValueProvider specification) + { + this.specification = Optional.of(specification); + return this; + } + + public Builder hashSymbol(String hashSymbol) + { + this.hashSymbol = Optional.of(hashSymbol); + return this; + } + + public PlanMatchPattern build() + { + PlanMatchPattern[] sources = source.map(sourcePattern -> new PlanMatchPattern[] {sourcePattern}).orElse(new PlanMatchPattern[] {}); + return node(TableFunctionProcessorNode.class, sources) + .with(new TableFunctionProcessorMatcher(name, properOutputs, passThroughSymbols, requiredSymbols, markerSymbols, specification, hashSymbol)); + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java new file mode 100644 index 0000000000000..bcae22ae6c623 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java @@ -0,0 +1,221 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneTableFunctionProcessorColumns + extends BaseRuleTest +{ + @Test + public void testDoNotPruneProperOutputs() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(p.variable("p")) + .source(p.values(p.variable("x")))))) + .doesNotFire(); + } + + @Test + public void testPrunePassThroughOutputs() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .matches(project( + ImmutableMap.of(), + tableFunctionProcessor(builder -> builder + .name("test_function") + .passThroughSymbols(ImmutableList.of(ImmutableList.of())), + values("a", "b")))); + + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression proper = p.variable("proper"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .matches(project( + ImmutableMap.of(), + tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("proper")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())), + values("a", "b")))); + } + + @Test + public void testReferencedPassThroughOutputs() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression x = p.variable("x"); + VariableReferenceExpression y = p.variable("y"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.builder().put(y, y).put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(x, y) + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .matches(project( + ImmutableMap.of("y", expression("y"), "b", expression("b")), + tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("x", "y")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("b"))), + values("a", "b")))); + } + + @Test + public void testAllPassThroughOutputsReferenced() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.builder().put(a, a).put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .doesNotFire(); + + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression proper = p.variable("proper"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.builder().put(a, a).put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .doesNotFire(); + } + + @Test + public void testNoSource() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(p.variable("proper"))))) + .doesNotFire(); + } + + @Test + public void testMultipleTableArguments() + { + // multiple pass-through specifications indicate that the table function has multiple table arguments + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.project( + Assignments.builder().put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(p.variable("proper")) + .passThroughSpecifications( + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(a, true))), + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(b, true))), + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(c, false)))) + .source(p.values(a, b, c, d)))); + }) + .matches(project( + ImmutableMap.of("b", expression("b")), + tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("proper")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of(), ImmutableList.of("b"), ImmutableList.of())), + values("a", "b", "c", "d")))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java new file mode 100644 index 0000000000000..68f56d320e396 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java @@ -0,0 +1,198 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_FIRST; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneTableFunctionProcessorSourceColumns + extends BaseRuleTest +{ + @Test + public void testPruneUnreferencedSymbol() + { + // symbols 'a', 'b', 'c', 'd', 'hash', and 'marker' are used by the node. + // symbol 'unreferenced' is pruned out. Also, the mapping for this symbol is removed from marker mappings + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression proper = p.variable("proper"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression unreferenced = p.variable("unreferenced"); + VariableReferenceExpression hash = p.variable("hash"); + VariableReferenceExpression marker = p.variable("marker"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(a, false)))) + .requiredSymbols(ImmutableList.of(ImmutableList.of(b))) + .markerSymbols(ImmutableMap.of( + a, marker, + b, marker, + c, marker, + d, marker, + unreferenced, marker)) + .specification(new DataOrganizationSpecification(ImmutableList.of(c), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_FIRST)))))) + .hashSymbol(hash) + .source(p.values(a, b, c, d, unreferenced, hash, marker))); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("proper")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("a"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("b"))) + .markerSymbols(ImmutableMap.of( + "a", "marker", + "b", "marker", + "c", "marker", + "d", "marker")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_FIRST))) + .hashSymbol("hash"), + project( + ImmutableMap.of( + "a", expression("a"), + "b", expression("b"), + "c", expression("c"), + "d", expression("d"), + "hash", expression("hash"), + "marker", expression("marker")), + values("a", "b", "c", "d", "unreferenced", "hash", "marker")))); + } + + @Test + public void testPruneUnusedMarkerSymbol() + { + // symbol 'unreferenced' is pruned out because the node does not use it. + // also, the mapping for this symbol is removed from marker mappings. + // because the marker symbol 'marker' is no longer used, it is pruned out too. + // note: currently a marker symbol cannot become unused because the function + // must use at least one symbol from each source. it might change in the future. + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression unreferenced = p.variable("unreferenced"); + VariableReferenceExpression marker = p.variable("marker"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .markerSymbols(ImmutableMap.of(unreferenced, marker)) + .source(p.values(unreferenced, marker))); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .markerSymbols(ImmutableMap.of()), + project( + ImmutableMap.of(), + values("unreferenced", "marker")))); + } + + @Test + public void testMultipleSources() + { + // multiple pass-through specifications indicate that the table function has multiple table arguments + // the third argument provides symbols 'e', 'f', and 'unreferenced'. those symbols are mapped to common marker symbol 'marker3' + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression marker1 = p.variable("marker1"); + VariableReferenceExpression marker2 = p.variable("marker2"); + VariableReferenceExpression marker3 = p.variable("marker3"); + VariableReferenceExpression unreferenced = p.variable("unreferenced"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .passThroughSpecifications( + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(a, false))), + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true)))) + .requiredSymbols(ImmutableList.of( + ImmutableList.of(b), + ImmutableList.of(d), + ImmutableList.of(f))) + .markerSymbols(ImmutableMap.of( + a, marker1, + b, marker1, + c, marker2, + d, marker2, + e, marker3, + f, marker3, + unreferenced, marker3)) + .source(p.values(a, b, c, d, e, f, marker1, marker2, marker3, unreferenced))); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .passThroughSymbols(ImmutableList.of(ImmutableList.of("a"), ImmutableList.of("c"), ImmutableList.of("e"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("b"), ImmutableList.of("d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "a", "marker1", + "b", "marker1", + "c", "marker2", + "d", "marker2", + "e", "marker3", + "f", "marker3")), + project( + ImmutableMap.of( + "a", expression("a"), + "b", expression("b"), + "c", expression("c"), + "d", expression("d"), + "e", expression("e"), + "f", expression("f"), + "marker1", expression("marker1"), + "marker2", expression("marker2"), + "marker3", expression("marker3")), + values("a", "b", "c", "d", "e", "f", "marker1", "marker2", "marker3", "unreferenced")))); + } + + @Test + public void allSymbolsReferenced() + { + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression marker = p.variable("marker"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .requiredSymbols(ImmutableList.of(ImmutableList.of(a))) + .markerSymbols(ImmutableMap.of(a, marker)) + .source(p.values(a, marker))); + }) + .doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveRedundantTableFunctionProcessor.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveRedundantTableFunctionProcessor.java new file mode 100644 index 0000000000000..86b6cc74e74af --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveRedundantTableFunctionProcessor.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestRemoveRedundantTableFunctionProcessor + extends BaseRuleTest +{ + @Test + public void testRemoveTableFunction() + { + tester().assertThat(new RemoveRedundantTableFunctionProcessor()) + .on(p -> { + VariableReferenceExpression passThrough = p.variable("pass_through"); + VariableReferenceExpression proper = p.variable("proper"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .pruneWhenEmpty() + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(passThrough, true)))) + .source(p.values(passThrough))); + }) + .matches(values("proper", "pass_through")); + } + + @Test + public void testDoNotRemoveKeepWhenEmpty() + { + tester().assertThat(new RemoveRedundantTableFunctionProcessor()) + .on(p -> { + VariableReferenceExpression passThrough = p.variable("pass_through"); + VariableReferenceExpression proper = p.variable("proper"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(passThrough, true)))) + .source(p.values(passThrough))); + }) + .doesNotFire(); + } + + @Test + public void testDoNotRemoveNonEmptyInput() + { + tester().assertThat(new RemoveRedundantTableFunctionProcessor()) + .on(p -> { + VariableReferenceExpression passThrough = p.variable("pass_through"); + VariableReferenceExpression proper = p.variable("proper"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .pruneWhenEmpty() + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(passThrough, true)))) + .source(p.values(5, passThrough))); + }) + .doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteExcludeColumnsFunctionToProjection.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteExcludeColumnsFunctionToProjection.java new file mode 100644 index 0000000000000..bfdcea0b8ca5f --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteExcludeColumnsFunctionToProjection.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.operator.table.ExcludeColumns.ExcludeColumnsFunctionHandle; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.SmallintType.SMALLINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestRewriteExcludeColumnsFunctionToProjection + extends BaseRuleTest +{ + @Test + public void rewriteExcludeColumnsFunction() + { + tester().assertThat(new RewriteExcludeColumnsFunctionToProjection()) + .on(p -> { + VariableReferenceExpression a = p.variable("a", BOOLEAN); + VariableReferenceExpression b = p.variable("b", BIGINT); + VariableReferenceExpression c = p.variable("c", SMALLINT); + VariableReferenceExpression x = p.variable("x", BIGINT); + VariableReferenceExpression y = p.variable("y", SMALLINT); + return p.tableFunctionProcessor( + builder -> builder + .name("exclude_columns") + .properOutputs(x, y) + .pruneWhenEmpty() + .requiredSymbols(ImmutableList.of(ImmutableList.of(b, c))) + .connectorHandle(new ExcludeColumnsFunctionHandle()) + .source(p.values(a, b, c))); + }) + .matches(PlanMatchPattern.strictProject( + ImmutableMap.of( + "x", expression("b"), + "y", expression("c")), + values("a", "b", "c"))); + } + + @Test + public void doNotRewriteOtherFunction() + { + tester().assertThat(new RewriteExcludeColumnsFunctionToProjection()) + .on(p -> { + VariableReferenceExpression a = p.variable("a", BOOLEAN); + VariableReferenceExpression b = p.variable("b", BIGINT); + VariableReferenceExpression c = p.variable("c", SMALLINT); + return p.tableFunctionProcessor( + builder -> builder + .name("testing_function") + .requiredSymbols(ImmutableList.of(ImmutableList.of(b, c))) + .source(p.values(a, b, c))); + }).doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformTableFunctionToTableFunctionProcessor.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformTableFunctionToTableFunctionProcessor.java new file mode 100644 index 0000000000000..b6fb48a904e81 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformTableFunctionToTableFunctionProcessor.java @@ -0,0 +1,1404 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.block.SortOrder.DESC_NULLS_FIRST; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TinyintType.TINYINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window; +import static com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; + +public class TestTransformTableFunctionToTableFunctionProcessor + extends BaseRuleTest +{ + @Test + public void testNoSources() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> p.tableFunction( + "test_function", + ImmutableList.of(p.variable("a")), + ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of())) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a")))); + } + + @Test + public void testSingleSourceWithRowSemantics() + { + // no pass-through columns + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new TableFunctionNode.TableArgumentProperties( + "table_argument", + true, + true, + new TableFunctionNode.PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"))), + values("c"))); + + // pass-through columns + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new TableFunctionNode.TableArgumentProperties( + "table_argument", + true, + true, + new TableFunctionNode.PassThroughSpecification(true, ImmutableList.of(new TableFunctionNode.PassThroughColumn(c, false))), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"))), + values("c"))); + } + + @Test + public void testSingleSourceWithSetSemantics() + { + // no pass-through columns, no partition by + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_LAST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"))) + .specification(specification(ImmutableList.of(), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST))), + values("c", "d"))); + + // no pass-through columns, partitioning column present + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new TableFunctionNode.PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_LAST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"))) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST))), + values("c", "d"))); + + // pass-through columns + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, false))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty())))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c", "d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"))) + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())), + values("c", "d"))); + } + + @Test + public void testTwoSourcesWithSetSemantics() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.empty())))), + ImmutableList.of()); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testThreeSourcesWithSetSemantics() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression g = p.variable("g"); + VariableReferenceExpression h = p.variable("h"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f), + p.values(g, h)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(h), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(h, DESC_NULLS_FIRST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"), ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"), ImmutableList.of("f"), ImmutableList.of("h"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2", + "g", "marker_3", + "h", "marker_3")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, input_3_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("combined_row_number_1_2 = input_3_row_number OR " + + "(combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = BIGINT '1' OR " + + "input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1')"), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f"))))), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of("h"), ImmutableMap.of("h", DESC_NULLS_FIRST))) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of("h"), ImmutableMap.of("h", DESC_NULLS_FIRST)) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("g", "h")))))))); + } + + @Test + public void testTwoCoPartitionedSources() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(f, DESC_NULLS_FIRST)))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM e) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST)) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testCoPartitionJoinTypes() + { + // both sources are prune when empty, so they are combined using inner join + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d")))))))); + + // only the left source is prune when empty, so sources are combined using left join + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR " + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d")))))))); + + // only the right source is prune when empty. the sources are reordered so that the prune when empty source is first. they are combined using left join + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), input_2_row_number, input_1_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), input_2_partition_size, input_1_partition_size)"), + "combined_partition_column", expression("COALESCE(d, c)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (d IS DISTINCT FROM c) " + + "AND (" + + " input_2_row_number = input_1_row_number OR" + + " (input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1' OR" + + " input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c")))))))); + + // neither source is prune when empty, so sources are combined using full join + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.FULL, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d)" + + " AND (" + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d")))))))); + } + + @Test + public void testThreeCoPartitionedSources() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2", "input_3"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + .specification(specification(ImmutableList.of("combined_partition_column_1_2_3"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, input_3_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)"), + "combined_partition_column_1_2_3", expression("COALESCE(combined_partition_column_1_2, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (combined_partition_column_1_2 IS DISTINCT FROM e) " + + "AND (" + + " combined_row_number_1_2 = input_3_row_number OR" + + " (combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = BIGINT '1' OR" + + " input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1'))"), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND (" + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))))), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("e")))))))); + } + + @Test + public void testTwoCoPartitionLists() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression g = p.variable("g"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e), + p.values(f, g)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty()))), + new TableArgumentProperties( + "input_4", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(f, true))), + ImmutableList.of(g), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(f), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(g, DESC_NULLS_FIRST)))))))), + ImmutableList.of( + ImmutableList.of("input_1", "input_2"), + ImmutableList.of("input_3", "input_4"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"), ImmutableList.of("f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"), ImmutableList.of("g"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3", + "f", "marker_4", + "g", "marker_4")) + .specification(specification(ImmutableList.of("combined_partition_column_1_2", "combined_partition_column_3_4"), ImmutableList.of("combined_row_number_1_2_3_4"), ImmutableMap.of("combined_row_number_1_2_3_4", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3_4, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3_4, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3_4, input_3_row_number, null)"), + "marker_4", expression("IF(input_4_row_number = combined_row_number_1_2_3_4, input_4_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3_4", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(combined_row_number_3_4, BIGINT '-1'), combined_row_number_1_2, combined_row_number_3_4)"), + "combined_partition_size_1_2_3_4", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(combined_partition_size_3_4, BIGINT '-1'), combined_partition_size_1_2, combined_partition_size_3_4)")), + join(// join nodes using helper symbols + JoinType.LEFT, + ImmutableList.of(), + Optional.of("combined_row_number_1_2 = combined_row_number_3_4 OR " + + "(combined_row_number_1_2 > combined_partition_size_3_4 AND combined_row_number_3_4 = BIGINT '1' OR " + + "combined_row_number_3_4 > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1')"), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))))), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_3_4", expression("IF(COALESCE(input_3_row_number, BIGINT '-1') > COALESCE(input_4_row_number, BIGINT '-1'), input_3_row_number, input_4_row_number)"), + "combined_partition_size_3_4", expression("IF(COALESCE(input_3_partition_size, BIGINT '-1') > COALESCE(input_4_partition_size, BIGINT '-1'), input_3_partition_size, input_4_partition_size)"), + "combined_partition_column_3_4", expression("COALESCE(e, f)")), + join(// co-partition nodes + JoinType.FULL, + ImmutableList.of(), + Optional.of("NOT (e IS DISTINCT FROM f) " + + "AND ( " + + "input_3_row_number = input_4_row_number OR " + + "(input_3_row_number > input_4_partition_size AND input_4_row_number = BIGINT '1' OR " + + "input_4_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("e"))), + window(// append helper symbols for source input_4 + builder -> builder + .specification(specification(ImmutableList.of("f"), ImmutableList.of("g"), ImmutableMap.of("g", DESC_NULLS_FIRST))) + .addFunction("input_4_partition_size", functionCall("count", ImmutableList.of())), + // input_4 + window(builder -> builder + .specification(ImmutableList.of("f"), ImmutableList.of("g"), ImmutableMap.of("g", DESC_NULLS_FIRST)) + .addFunction("input_4_row_number", functionCall("row_number", ImmutableList.of())), + values("f", "g")))))))))); + } + + @Test + public void testCoPartitionedAndNotCoPartitionedSources() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_2", "input_3"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + .specification(specification(ImmutableList.of("combined_partition_column_2_3", "c"), ImmutableList.of("combined_row_number_2_3_1"), ImmutableMap.of("combined_row_number_2_3_1", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_2_3_1, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_2_3_1, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_2_3_1, input_3_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_2_3_1", expression("IF(COALESCE(combined_row_number_2_3, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), combined_row_number_2_3, input_1_row_number)"), + "combined_partition_size_2_3_1", expression("IF(COALESCE(combined_partition_size_2_3, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), combined_partition_size_2_3, input_1_partition_size)")), + join(// join nodes using helper symbols + JoinType.INNER, + ImmutableList.of(), + Optional.of("combined_row_number_2_3 = input_1_row_number OR " + + "(combined_row_number_2_3 > input_1_partition_size AND input_1_row_number = BIGINT '1' OR " + + "input_1_row_number > combined_partition_size_2_3 AND combined_row_number_2_3 = BIGINT '1')"), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_2_3", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), input_2_row_number, input_3_row_number)"), + "combined_partition_size_2_3", expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), input_2_partition_size, input_3_partition_size)"), + "combined_partition_column_2_3", expression("COALESCE(d, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (d IS DISTINCT FROM e) " + + "AND ( " + + "input_2_row_number = input_3_row_number OR " + + "(input_2_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1' OR " + + "input_3_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("e"))))), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c")))))))); + } + + @Test + public void testCoerceForCopartitioning() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c", TINYINT); + VariableReferenceExpression cCoerced = p.variable("c_coerced", INTEGER); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e", INTEGER); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + // coerce column c for co-partitioning + p.project( + Assignments.builder() + .put(c, p.rowExpression("c")) + .put(d, p.rowExpression("d")) + .put(cCoerced, p.rowExpression("CAST(c AS INTEGER)")) + .build(), + p.values(c, d)), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(cCoerced), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(f, DESC_NULLS_FIRST)))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "c_coerced", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c_coerced, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c_coerced IS DISTINCT FROM e) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c_coerced"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c_coerced"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + project( + ImmutableMap.of("c_coerced", expression("CAST(c AS INTEGER)")), + values("c", "d")))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST)) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testTwoCoPartitioningColumns() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c, d), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e, f), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c", "d"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column_1", "combined_partition_column_2"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1", expression("COALESCE(c, e)"), + "combined_partition_column_2", expression("COALESCE(d, f)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM e) " + + "AND NOT (d IS DISTINCT FROM f) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c", "d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c", "d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e", "f"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("e", "f"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testTwoSourcesWithRowAndSetSemantics() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + true, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(e), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 9329b3326441f..73f200c33c3a6 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -20,6 +20,7 @@ import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.IndexHandle; @@ -29,6 +30,7 @@ import com.facebook.presto.spi.connector.RowChangeParadigm; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.AggregationNode.Step; @@ -92,6 +94,8 @@ import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; import com.facebook.presto.sql.tree.Expression; @@ -1006,6 +1010,32 @@ public WindowNode window(DataOrganizationSpecification specification, Map properOutputs, + List sources, + List tableArgumentProperties, + List> copartitioningLists) + + { + return new TableFunctionNode( + idAllocator.getNextId(), + name, + ImmutableMap.of(), + properOutputs, + sources, + tableArgumentProperties, + copartitioningLists, + new TableFunctionHandle(new ConnectorId("connector_id"), new ConnectorTableFunctionHandle() {}, TestingTransactionHandle.create())); + } + + public TableFunctionProcessorNode tableFunctionProcessor(Consumer consumer) + { + TableFunctionProcessorBuilder tableFunctionProcessorBuilder = new TableFunctionProcessorBuilder(); + consumer.accept(tableFunctionProcessorBuilder); + return tableFunctionProcessorBuilder.build(idAllocator); + } + public RowNumberNode rowNumber(List partitionBy, Optional maxRowCountPerPartition, VariableReferenceExpression rowNumberVariable, PlanNode source) { return new RowNumberNode( diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java new file mode 100644 index 0000000000000..404831b10f0ef --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java @@ -0,0 +1,140 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule.test; + +import com.facebook.presto.metadata.TableFunctionHandle; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.testing.TestingTransactionHandle; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +public class TableFunctionProcessorBuilder +{ + private String name; + private List properOutputs = ImmutableList.of(); + private Optional source = Optional.empty(); + private boolean pruneWhenEmpty; + private List passThroughSpecifications = ImmutableList.of(); + private List> requiredSymbols = ImmutableList.of(); + private Optional> markerSymbols = Optional.empty(); + private Optional specification = Optional.empty(); + private Set prePartitioned = ImmutableSet.of(); + private int preSorted; + private Optional hashSymbol = Optional.empty(); + private ConnectorTableFunctionHandle connectorHandle = new ConnectorTableFunctionHandle() {}; + + public TableFunctionProcessorBuilder() {} + + public TableFunctionProcessorBuilder name(String name) + { + this.name = name; + return this; + } + + public TableFunctionProcessorBuilder properOutputs(VariableReferenceExpression... properOutputs) + { + this.properOutputs = ImmutableList.copyOf(properOutputs); + return this; + } + + public TableFunctionProcessorBuilder source(PlanNode source) + { + this.source = Optional.of(source); + return this; + } + + public TableFunctionProcessorBuilder pruneWhenEmpty() + { + this.pruneWhenEmpty = true; + return this; + } + + public TableFunctionProcessorBuilder passThroughSpecifications(PassThroughSpecification... passThroughSpecifications) + { + this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); + return this; + } + + public TableFunctionProcessorBuilder requiredSymbols(List> requiredSymbols) + { + this.requiredSymbols = requiredSymbols; + return this; + } + + public TableFunctionProcessorBuilder markerSymbols(Map markerSymbols) + { + this.markerSymbols = Optional.of(markerSymbols); + return this; + } + + public TableFunctionProcessorBuilder specification(DataOrganizationSpecification specification) + { + this.specification = Optional.of(specification); + return this; + } + + public TableFunctionProcessorBuilder prePartitioned(Set prePartitioned) + { + this.prePartitioned = prePartitioned; + return this; + } + + public TableFunctionProcessorBuilder preSorted(int preSorted) + { + this.preSorted = preSorted; + return this; + } + + public TableFunctionProcessorBuilder hashSymbol(VariableReferenceExpression hashSymbol) + { + this.hashSymbol = Optional.of(hashSymbol); + return this; + } + + public TableFunctionProcessorBuilder connectorHandle(ConnectorTableFunctionHandle connectorHandle) + { + this.connectorHandle = connectorHandle; + return this; + } + + public TableFunctionProcessorNode build(PlanNodeIdAllocator idAllocator) + { + return new TableFunctionProcessorNode( + idAllocator.getNextId(), + name, + properOutputs, + source, + pruneWhenEmpty, + passThroughSpecifications, + requiredSymbols, + markerSymbols, + specification, + prePartitioned, + preSorted, + hashSymbol, + new TableFunctionHandle(new ConnectorId("connector_id"), connectorHandle, TestingTransactionHandle.create())); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java index 5c9f7bbe6099d..82cd2dc7fece9 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java @@ -203,6 +203,7 @@ public void run() NodeInfo nodeInfo = injector.getInstance(NodeInfo.class); PluginNodeManager pluginNodeManager = new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()); planCheckerProviderManager.loadPlanCheckerProviders(pluginNodeManager); + injector.getInstance(FunctionAndTypeManager.class).loadTVFProviders(pluginNodeManager); if (injector.getInstance(FeaturesConfig.class).isBuiltInSidecarFunctionsEnabled()) { List functions = injector.getInstance(WorkerFunctionRegistryTool.class).getWorkerFunctions(); diff --git a/presto-native-execution/pom.xml b/presto-native-execution/pom.xml index 30f44d30d5c93..2ae94e38d7ee4 100644 --- a/presto-native-execution/pom.xml +++ b/presto-native-execution/pom.xml @@ -372,6 +372,12 @@ ${project.version} test + + + com.facebook.presto + presto-native-tvf + test + diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index 9138bb9be597b..e7e4e12df2dab 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -22,6 +22,10 @@ add_library(presto_session_properties SessionProperties.cpp) target_link_libraries(presto_session_properties ${FOLLY_WITH_DEPENDENCIES}) +# if(PRESTO_ENABLE_TABLE_FUNCTIONS) add_subdirectory(tvf) +add_subdirectory(tvf) +# endif() + add_library( presto_server_lib Announcer.cpp @@ -110,6 +114,8 @@ target_link_libraries( pthread ) +target_link_libraries(presto_server_lib presto_tvf_exec presto_tvf_functions) + if(PRESTO_ENABLE_CUDF) target_link_libraries(presto_server_lib velox_cudf_exec) endif() diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index 8445b45f217e3..081204e7ca911 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -43,6 +43,8 @@ #include "presto_cpp/main/operators/ShuffleExchangeSource.h" #include "presto_cpp/main/operators/ShuffleRead.h" #include "presto_cpp/main/operators/ShuffleWrite.h" +#include "presto_cpp/main/tvf/exec/TableFunctionTranslator.h" +#include "presto_cpp/main/tvf/functions/TableFunctionsRegistration.h" #include "presto_cpp/main/types/ExpressionOptimizer.h" #include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h" #include "presto_cpp/main/types/VeloxPlanConversion.h" @@ -468,6 +470,37 @@ void PrestoServer::run() { http::kMimeTypeApplicationJson) .sendWithEOM(); }); + httpServer_->registerGet( + "/v1/functions/tvf", + [](proxygen::HTTPMessage* /*message*/, + const std::vector>& /*body*/, + proxygen::ResponseHandler* downstream) { + http::sendOkResponse(downstream, getTableValuedFunctionsMetadata()); + }); + httpServer_->registerPost( + "/v1/tvf/analyze", + [server = this]( + proxygen::HTTPMessage* message, + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + http::sendOkResponse( + downstream, + getAnalyzedTableValueFunction( + util::extractMessageBody(body), + server->nativeWorkerPool_.get())); + }); + httpServer_->registerPost( + "/v1/tvf/splits", + [server = this]( + proxygen::HTTPMessage* message, + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + http::sendOkResponse( + downstream, + getSplits( + util::extractMessageBody(body), + server->nativeWorkerPool_.get())); + }); if (systemConfig->enableRuntimeMetricsCollection()) { enableWorkerStatsReporting(); @@ -1408,6 +1441,10 @@ void PrestoServer::registerCustomOperators() { // which will allow server specific operator registration. velox::exec::Operator::registerOperator( std::make_unique()); + + // Table functions translator. + velox::exec::Operator::registerOperator( + std::make_unique()); } void PrestoServer::registerFunctions() { @@ -1423,6 +1460,8 @@ void PrestoServer::registerFunctions() { velox::connector::hasConnector("hive-hadoop2")) { hive::functions::registerHiveNativeFunctions(); } + + tvf::registerAllTableFunctions(prestoBuiltinFunctionPrefix_); } void PrestoServer::registerRemoteFunctions() { diff --git a/presto-native-execution/presto_cpp/main/connectors/Registration.cpp b/presto-native-execution/presto_cpp/main/connectors/Registration.cpp index 52cfd06f2a15d..717593f9d5ac6 100644 --- a/presto-native-execution/presto_cpp/main/connectors/Registration.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/Registration.cpp @@ -75,6 +75,9 @@ void registerConnectors() { std::make_unique("system")); registerPrestoToVeloxConnector( std::make_unique("$system@system")); + registerPrestoToVeloxConnector( + std::make_unique( + "system:com.facebook.presto.tvf.NativeTableFunctionSplit")); #ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR registerPrestoToVeloxConnector( diff --git a/presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp index 9251c70f1a257..2c4bbf9f70516 100644 --- a/presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp @@ -14,6 +14,7 @@ #include "presto_cpp/main/connectors/SystemConnector.h" #include "presto_cpp/main/PrestoTask.h" #include "presto_cpp/main/TaskManager.h" +#include "presto_cpp/main/tvf/exec/TableFunctionSplit.h" #include "velox/type/Timestamp.h" @@ -392,4 +393,48 @@ std::unique_ptr SystemPrestoToVeloxConnector::createConnectorProtocol() const { return std::make_unique(); } + +std::unique_ptr +TvfNativePrestoToVeloxConnector::toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const protocol::SplitContext* splitContext) const { + auto nativeSplit = dynamic_cast(connectorSplit); + VELOX_CHECK_NOT_NULL( + nativeSplit, "Unexpected split type {}", connectorSplit->_type); + return std::make_unique( + ISerializable::deserialize( + folly::parseJson(nativeSplit->serializedTableFunctionSplitHandle))); +} + +std::unique_ptr +TvfNativePrestoToVeloxConnector::toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const { + auto systemColumn = dynamic_cast(column); + VELOX_CHECK_NOT_NULL( + systemColumn, "Unexpected column handle type {}", column->_type); + return std::make_unique(systemColumn->columnName); +} + +std::unique_ptr +TvfNativePrestoToVeloxConnector::toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser) const { + auto systemLayout = + std::dynamic_pointer_cast( + tableHandle.connectorTableLayout); + VELOX_CHECK_NOT_NULL( + systemLayout, "Unexpected table handle type {}", tableHandle.connectorId); + return std::make_unique( + tableHandle.connectorId, + systemLayout->table.schemaName, + systemLayout->table.tableName); +} + +std::unique_ptr +TvfNativePrestoToVeloxConnector::createConnectorProtocol() const { + return std::make_unique(); +} } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/SystemConnector.h b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.h index e36426bb481db..6987873d64060 100644 --- a/presto-native-execution/presto_cpp/main/connectors/SystemConnector.h +++ b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.h @@ -197,4 +197,28 @@ class SystemPrestoToVeloxConnector final : public PrestoToVeloxConnector { const final; }; +class TvfNativePrestoToVeloxConnector final : public PrestoToVeloxConnector { + public: + explicit TvfNativePrestoToVeloxConnector(std::string connectorId) + : PrestoToVeloxConnector(std::move(connectorId)) {} + + std::unique_ptr toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const protocol::SplitContext* splitContext) const final; + + std::unique_ptr toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const final; + + std::unique_ptr toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser) + const final; + + std::unique_ptr createConnectorProtocol() + const final; +}; + } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt index 20020ea182e5d..7cf3f12f782df 100644 --- a/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt @@ -11,7 +11,13 @@ # limitations under the License. add_library(presto_function_metadata OBJECT FunctionMetadata.cpp) -target_link_libraries(presto_function_metadata presto_common velox_function_registry) +target_link_libraries( + presto_function_metadata + presto_common + velox_function_registry + presto_tvf_spi + presto_types +) add_subdirectory(dynamic_registry) diff --git a/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.cpp b/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.cpp index 6c7262a83b1b8..3f9b08354c02f 100644 --- a/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.cpp +++ b/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.cpp @@ -12,6 +12,10 @@ * limitations under the License. */ #include "presto_cpp/main/functions/FunctionMetadata.h" +#include "presto_cpp/main/tvf/spi/TableFunction.h" +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "presto_cpp/main/types/TypeParser.h" + #include "presto_cpp/main/common/Utils.h" #include "presto_cpp/presto_protocol/core/presto_protocol_core.h" #include "velox/exec/Aggregate.h" @@ -21,6 +25,7 @@ using namespace facebook::velox; using namespace facebook::velox::exec; +using namespace facebook::presto::tvf; namespace facebook::presto { namespace { @@ -251,6 +256,101 @@ json buildWindowMetadata( return j; } +protocol::Descriptor buildDescriptor(const Descriptor& descriptor) { + // types could be empty, pre-process that case + auto names = descriptor.names(); + auto types = descriptor.types(); + std::vector fields; + for (int i = 0; i < names.size(); i++) { + std::shared_ptr type = (i < types.size()) + ? std::make_shared(types.at(i)->toString()) + : nullptr; + fields.emplace_back( + protocol::Field{std::make_shared(names.at(i)), type}); + } + return protocol::Descriptor{fields}; +} + +protocol::NativeDescriptor buildNativeDescriptor(const Descriptor& descriptor) { + // types could be empty, pre-process that case + auto names = descriptor.names(); + auto types = descriptor.types(); + std::vector fields; + for (int i = 0; i < names.size(); i++) { + std::shared_ptr type = (i < types.size()) + ? std::make_shared(types.at(i)->toString()) + : nullptr; + fields.emplace_back( + protocol::NativeField{ + std::make_shared(names.at(i)), type}); + } + return protocol::NativeDescriptor{fields}; +} + +std::vector> +buildArgumentSpecsList(TableArgumentSpecList argumentsSpec) { + std::vector> + argumentsSpecsList; + for (const auto argumentSpec : argumentsSpec) { + if (auto scalarArgumentSpec = + std::dynamic_pointer_cast( + argumentSpec)) { + auto scalarArgumentSpecification = + std::make_shared(); + scalarArgumentSpecification->name = scalarArgumentSpec->name(); + scalarArgumentSpecification->required = scalarArgumentSpec->required(); + scalarArgumentSpecification->type = + scalarArgumentSpec->rowType()->toString(); + argumentsSpecsList.emplace_back(scalarArgumentSpecification); + } else if ( + auto tableArgumentSpec = + std::dynamic_pointer_cast( + argumentSpec)) { + auto tableArgumentSpecification = + std::make_shared(); + tableArgumentSpecification->name = tableArgumentSpec->name(); + tableArgumentSpecification->passThroughColumns = false; + tableArgumentSpecification->pruneWhenEmpty = true; + tableArgumentSpecification->rowSemantics = true; + argumentsSpecsList.emplace_back(tableArgumentSpecification); + } else if ( + auto descriptorArgumentSpec = + std::dynamic_pointer_cast( + argumentSpec)) { + auto descriptorArgumentSpecification = + std::make_shared(); + descriptorArgumentSpecification->name = descriptorArgumentSpec->name(); + descriptorArgumentSpecification->defaultValue = + buildDescriptor(descriptorArgumentSpec->descriptor()); + descriptorArgumentSpecification->required = false; + argumentsSpecsList.emplace_back(descriptorArgumentSpecification); + } else { + VELOX_FAIL("Failed to convert to a valid argumentSpec"); + } + } + return argumentsSpecsList; +} + +std::shared_ptr buildReturnTypeSpecification( + ReturnSpecPtr returnSpec) { + auto returnTypeSpecification = returnSpec->returnType(); + if (returnTypeSpecification == + ReturnTypeSpecification::ReturnType::kGenericTable) { + std::shared_ptr + genericTableReturnTypeSpecification = + std::make_shared(); + return genericTableReturnTypeSpecification; + } else { + std::shared_ptr + describedTableReturnTypeSpecification = + std::make_shared(); + auto describedTable = + std::dynamic_pointer_cast(returnSpec); + describedTableReturnTypeSpecification->descriptor = + buildDescriptor(*(describedTable->descriptor())); + return describedTableReturnTypeSpecification; + } +} } // namespace json getFunctionsMetadata(const std::optional& catalog) { @@ -320,4 +420,145 @@ json getFunctionsMetadata(const std::optional& catalog) { return j; } +json getTableValuedFunctionsMetadata() { + json j; + // Get metadata for all registered table valued functions in velox. + const auto signatures = tableFunctions(); + for (const auto& entry : signatures) { + const auto parts = util::getFunctionNameParts(entry.first); + const auto functionName = parts[2]; + + protocol::JsonBasedTableFunctionMetadata function; + json tj; + function.functionName = entry.first; + function.returnTypeSpecification = + buildReturnTypeSpecification(getTableFunctionReturnType(entry.first)); + function.arguments = + buildArgumentSpecsList(getTableFunctionArgumentSpecs(entry.first)); + protocol::to_json(tj, function); + j[functionName] = tj; + } + return j; +} + +protocol::Map> +getRequiredColumns(const tvf::TableFunctionAnalysis* tableFunctionAnalysis) { + protocol::Map> + requiredColumns; + for (auto& [k, v] : tableFunctionAnalysis->requiredColumns()) { + std::vector values; + for (int i : v) { + values.emplace_back(i); + } + requiredColumns.insert({k, values}); + } + return requiredColumns; +} + +protocol::NativeTableFunctionHandle buildNativeTableFunctionHandle( + const TableFunctionHandlePtr tableFunctionHandle, + const std::string& functionName) { + protocol::NativeTableFunctionHandle handle; + handle.functionName = functionName; + handle.serializedTableFunctionHandle = + folly::toJson(tableFunctionHandle->serialize()); + return handle; +} + +protocol::NativeTableFunctionAnalysis getNativeTableFunctionAnalysis( + std::string functionName, + std::unordered_map> args) { + auto tableFunctionAnalysis = tvf::TableFunction::analyze(functionName, args); + protocol::NativeTableFunctionAnalysis nativeTableFunctionAnalysis; + nativeTableFunctionAnalysis.requiredColumns = + getRequiredColumns(tableFunctionAnalysis.get()); + nativeTableFunctionAnalysis.returnedType = nullptr; + if (tableFunctionAnalysis->returnType()) { + nativeTableFunctionAnalysis.returnedType = std::make_shared( + buildNativeDescriptor(*tableFunctionAnalysis->returnType())); + } + nativeTableFunctionAnalysis.handle = buildNativeTableFunctionHandle( + tableFunctionAnalysis->tableFunctionHandle(), functionName); + return nativeTableFunctionAnalysis; +} + +json getAnalyzedTableValueFunction( + const std::string& connectorTableMetadataJson, + velox::memory::MemoryPool* pool) { + TypeParser parser; + VeloxExprConverter exprConverter{pool, &parser}; + protocol::ConnectorTableMetadata connectorTableMetadata = + json::parse(connectorTableMetadataJson); + std::unordered_map> args; + for (const auto& entry : connectorTableMetadata.arguments) { + std::shared_ptr functionArg; + if (auto scalarArgument = + std::dynamic_pointer_cast(entry.second)) { + auto serializableNullableValue = + scalarArgument->nullableValue.serializable; + auto value = exprConverter.getConstantValue( + parser.parse(serializableNullableValue.type), + serializableNullableValue.block); + functionArg = std::make_shared( + value.inferType(), + BaseVector::createConstant( + value.inferType(), + value, + serializableNullableValue.block.data.size(), + pool)); + } else if ( + auto tableArgument = + std::dynamic_pointer_cast(entry.second)) { + std::vector fieldNames; + std::vector fieldTypes; + for (auto& arg : tableArgument->fields) { + fieldNames.push_back(boost::algorithm::to_lower_copy(*arg.name)); + fieldTypes.push_back(parser.parse(*arg.type)); + } + functionArg = std::make_shared( + ROW(std::move(fieldNames), std::move(fieldTypes))); + } else if ( + auto descriptorArgument = + std::dynamic_pointer_cast( + entry.second)) { + std::vector fieldNames; + std::vector fieldTypes; + for (auto& arg : descriptorArgument->descriptor->fields) { + fieldNames.push_back(boost::algorithm::to_lower_copy(*arg.name)); + //fieldTypes.push_back(parser.parse(*arg.type)); + } + functionArg = std::make_shared( + std::move(fieldNames) + // , std::move(fieldTypes) + ); + } else { + VELOX_UNSUPPORTED("Failed to convert to a valid Argument"); + } + args[entry.first] = functionArg; + } + return json(getNativeTableFunctionAnalysis( + connectorTableMetadata.functionName, args)); +} + +json getSplits( + const std::string& connectorTableFunctionHandle, + velox::memory::MemoryPool* pool) { + protocol::NativeTableFunctionHandle handle = + json::parse(connectorTableFunctionHandle); + + const auto splits = tvf::TableFunction::getSplits( + handle.functionName, + ISerializable::deserialize( + folly::parseJson(handle.serializedTableFunctionHandle))); + + json j = json::array(); + protocol::NativeTableFunctionSplit jsonBasedTableFunctionSplit; + for (const auto& entry : splits) { + json tj; + jsonBasedTableFunctionSplit.serializedTableFunctionSplitHandle = folly::toJson(entry->serialize()); + to_json(tj, jsonBasedTableFunctionSplit); + j.push_back(tj); + } + return j; +} } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.h b/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.h index d2a2c66d7a489..8ea61ec06193d 100644 --- a/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.h +++ b/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.h @@ -16,6 +16,7 @@ #include #include "presto_cpp/external/json/nlohmann/json.hpp" +#include "velox/common/memory/MemoryPool.h" namespace facebook::presto { @@ -23,4 +24,15 @@ namespace facebook::presto { nlohmann::json getFunctionsMetadata( const std::optional& catalog = std::nullopt); +// Returns metadata for all registered table valued functions as json. +nlohmann::json getTableValuedFunctionsMetadata(); + +nlohmann::json getAnalyzedTableValueFunction( + const std::string& connectorTableMetadataJson, + velox::memory::MemoryPool* pool); + +nlohmann::json getSplits( + const std::string& connectorTableFunctionHandle, + velox::memory::MemoryPool* pool); + } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/tvf/CMakeLists.txt b/presto-native-execution/presto_cpp/main/tvf/CMakeLists.txt new file mode 100644 index 0000000000000..206aec8589ead --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/CMakeLists.txt @@ -0,0 +1,20 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_subdirectory(core) +add_subdirectory(exec) +add_subdirectory(spi) +add_subdirectory(functions) + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/tvf/core/CMakeLists.txt b/presto-native-execution/presto_cpp/main/tvf/core/CMakeLists.txt new file mode 100644 index 0000000000000..332eb3455a5ae --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/core/CMakeLists.txt @@ -0,0 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_library(presto_tvf_core OBJECT TableFunctionProcessorNode.cpp) + +target_link_libraries(presto_tvf_core velox_core) diff --git a/presto-native-execution/presto_cpp/main/tvf/core/TableFunctionProcessorNode.cpp b/presto-native-execution/presto_cpp/main/tvf/core/TableFunctionProcessorNode.cpp new file mode 100644 index 0000000000000..6e5023d3b6465 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/core/TableFunctionProcessorNode.cpp @@ -0,0 +1,228 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/tvf/core/TableFunctionProcessorNode.h" + +namespace facebook::presto::tvf { + +using namespace facebook::velox; +using namespace facebook::velox::core; + +TableFunctionProcessorNode::TableFunctionProcessorNode( + PlanNodeId id, + std::string name, + TableFunctionHandlePtr handle, + std::vector partitionKeys, + std::vector sortingKeys, + std::vector sortingOrders, + velox::RowTypePtr outputType, + std::vector requiredColumns, + std::vector sources) + : PlanNode(std::move(id)), + functionName_(std::move(name)), + handle_(std::move(handle)), + partitionKeys_(std::move(partitionKeys)), + sortingKeys_(std::move(sortingKeys)), + sortingOrders_(std::move(sortingOrders)), + outputType_(std::move(outputType)), + requiredColumns_(std::move(requiredColumns)), + sources_{std::move(sources)} { + VELOX_CHECK_EQ( + sortingKeys_.size(), + sortingOrders_.size(), + "Number of sorting keys must be equal to the number of sorting orders"); + + std::unordered_set keyNames; + for (const auto& key : partitionKeys_) { + VELOX_USER_CHECK( + keyNames.insert(key->name()).second, + "Partitioning keys must be unique. Found duplicate key: {}", + key->name()); + } + + for (const auto& key : sortingKeys_) { + VELOX_USER_CHECK( + keyNames.insert(key->name()).second, + "Sorting keys must be unique and not overlap with partitioning keys. Found duplicate key: {}", + key->name()); + } + + VELOX_CHECK_LE( + sources_.size(), 1, "Number of sources must be equal to 0 or 1"); +} + +namespace { +void appendComma(int32_t i, std::stringstream& sql) { + if (i > 0) { + sql << ", "; + } +} + +void addFields( + std::stringstream& stream, + const std::vector& keys) { + for (auto i = 0; i < keys.size(); ++i) { + appendComma(i, stream); + stream << keys[i]->name(); + } +} + +void addKeys(std::stringstream& stream, const std::vector& keys) { + for (auto i = 0; i < keys.size(); ++i) { + const auto& expr = keys[i]; + appendComma(i, stream); + if (auto field = TypedExprs::asFieldAccess(expr)) { + stream << field->name(); + } else if (auto constant = TypedExprs::asConstant(expr)) { + stream << constant->toString(); + } else { + stream << expr->toString(); + } + } +} + +void addSortingKeys( + const std::vector& sortingKeys, + const std::vector& sortingOrders, + std::stringstream& stream) { + for (auto i = 0; i < sortingKeys.size(); ++i) { + appendComma(i, stream); + stream << sortingKeys[i]->name() << " " << sortingOrders[i].toString(); + } +} + +} // namespace + +void TableFunctionProcessorNode::addDetails(std::stringstream& stream) const { + if (!partitionKeys_.empty()) { + stream << "partition by ["; + addFields(stream, partitionKeys_); + stream << "] "; + } + + if (!sortingKeys_.empty()) { + stream << "order by ["; + addSortingKeys(sortingKeys_, sortingOrders_, stream); + stream << "] "; + } +} + +namespace { +folly::dynamic serializeSortingOrders( + const std::vector& sortingOrders) { + auto array = folly::dynamic::array(); + for (const auto& order : sortingOrders) { + array.push_back(order.serialize()); + } + + return array; +} + +std::vector deserializeSortingOrders(const folly::dynamic& array) { + std::vector sortingOrders; + sortingOrders.reserve(array.size()); + for (const auto& order : array) { + sortingOrders.push_back(SortOrder::deserialize(order)); + } + return sortingOrders; +} + +} // namespace + +folly::dynamic TableFunctionProcessorNode::serialize() const { + auto obj = PlanNode::serialize(); + if (handle_) { + obj["handle"] = handle_->serialize(); + } + + obj["partitionKeys"] = ISerializable::serialize(partitionKeys_); + obj["sortingKeys"] = ISerializable::serialize(sortingKeys_); + obj["sortingOrders"] = serializeSortingOrders(sortingOrders_); + + obj["functionName"] = functionName_.data(); + obj["outputType"] = outputType_->serialize(); + + obj["requiredColumns"] = ISerializable::serialize(requiredColumns_); + + return obj; +} + +namespace { +std::vector deserializeSources( + const folly::dynamic& obj, + void* context) { + if (obj.count("sources")) { + return ISerializable::deserialize>( + obj["sources"], context); + } + + return {}; +} + +PlanNodePtr deserializeSingleSource(const folly::dynamic& obj, void* context) { + auto sources = deserializeSources(obj, context); + VELOX_CHECK_EQ(1, sources.size()); + + return sources[0]; +} + +PlanNodeId deserializePlanNodeId(const folly::dynamic& obj) { + return obj["id"].asString(); +} + +RowTypePtr deserializeRowType(const folly::dynamic& obj) { + return ISerializable::deserialize(obj); +} + +std::vector deserializeFields( + const folly::dynamic& array, + void* context) { + return ISerializable::deserialize>( + array, context); +} + +} // namespace + +// static +PlanNodePtr TableFunctionProcessorNode::create( + const folly::dynamic& obj, + void* context) { + auto sources = deserializeSources(obj, context); + auto outputType = deserializeRowType(obj["outputType"]); + auto handle = ISerializable::deserialize(obj["handle"]); + VELOX_CHECK(handle); + + auto partitionKeys = deserializeFields(obj["partitionKeys"], context); + auto sortingKeys = deserializeFields(obj["sortingKeys"], context); + + auto sortingOrders = deserializeSortingOrders(obj["sortingOrders"]); + + auto name = obj["functionName"].asString(); + + auto requiredColumns = + deserialize>(obj["requiredColumns"]); + + return std::make_shared( + deserializePlanNodeId(obj), + name, + handle, + partitionKeys, + sortingKeys, + sortingOrders, + outputType, + requiredColumns, + sources); +} + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/core/TableFunctionProcessorNode.h b/presto-native-execution/presto_cpp/main/tvf/core/TableFunctionProcessorNode.h new file mode 100644 index 0000000000000..88253a16bd041 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/core/TableFunctionProcessorNode.h @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/main/tvf/spi/TableFunction.h" + +#include "velox/core/Expressions.h" +#include "velox/core/PlanNode.h" + +namespace facebook::presto::tvf { + +class TableFunctionProcessorNode : public velox::core::PlanNode { + public: + TableFunctionProcessorNode( + velox::core::PlanNodeId id, + std::string name, + TableFunctionHandlePtr handle, + std::vector partitionKeys, + std::vector sortingKeys, + std::vector sortingOrders, + velox::RowTypePtr outputType, + std::vector requiredColumns, + std::vector sources); + + const std::vector& sources() const override { + return sources_; + } + + bool canSpill(const velox::core::QueryConfig& queryConfig) const override { + return false; + } + + const velox::RowTypePtr& inputType() const { + return sources_[0]->outputType(); + } + + const velox::RowTypePtr& outputType() const override { + return outputType_; + }; + + std::string_view name() const override { + return "TableFunctionProcessor"; + } + + const std::string functionName() const { + return functionName_; + } + + const std::shared_ptr handle() const { + return handle_; + } + + const std::vector& partitionKeys() + const { + return partitionKeys_; + } + + const std::vector& sortingKeys() const { + return sortingKeys_; + } + + const std::vector& sortingOrders() const { + return sortingOrders_; + } + + const std::vector& requiredColumns() const { + return requiredColumns_; + } + + bool requiresSplits() const override { + if (sources_.empty()) { + // This is a leaf operator that needs splits then. + return true; + } + + return false; + } + + folly::dynamic serialize() const override; + + static velox::core::PlanNodePtr create( + const folly::dynamic& obj, + void* context); + + private: + void addDetails(std::stringstream& stream) const override; + + const std::string functionName_; + + TableFunctionHandlePtr handle_; + + std::vector partitionKeys_; + std::vector sortingKeys_; + std::vector sortingOrders_; + + const velox::RowTypePtr outputType_; + + const std::vector requiredColumns_; + + const std::vector sources_; +}; + +using TableFunctionProcessorNodePtr = + std::shared_ptr; + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/exec/CMakeLists.txt b/presto-native-execution/presto_cpp/main/tvf/exec/CMakeLists.txt new file mode 100644 index 0000000000000..4c46a57998e9b --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/exec/CMakeLists.txt @@ -0,0 +1,24 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# target_include_directories(presto_tvf_exec PRIVATE +# ${CMAKE_SOURCE_DIR}/presto_cpp/main/tvf/core) + +add_library( + presto_tvf_exec + LeafTableFunctionOperator.cpp + TableFunctionOperator.cpp + TableFunctionPartition.cpp + TablePartitionBuild.cpp +) + +target_link_libraries(presto_tvf_exec presto_tvf_core presto_tvf_spi velox_exec) diff --git a/presto-native-execution/presto_cpp/main/tvf/exec/LeafTableFunctionOperator.cpp b/presto-native-execution/presto_cpp/main/tvf/exec/LeafTableFunctionOperator.cpp new file mode 100644 index 0000000000000..098f9928f2ecf --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/exec/LeafTableFunctionOperator.cpp @@ -0,0 +1,127 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/tvf/exec/LeafTableFunctionOperator.h" + +#include "velox/common/memory/MemoryArbitrator.h" +#include "velox/exec/Task.h" +#include "velox/vector/ComplexVector.h" + +namespace facebook::presto::tvf { + +using namespace facebook::velox; +using namespace facebook::velox::exec; + +LeafTableFunctionOperator::LeafTableFunctionOperator( + int32_t operatorId, + DriverCtx* driverCtx, + const TableFunctionProcessorNodePtr& tableFunctionProcessorNode) + : SourceOperator( + driverCtx, + tableFunctionProcessorNode->outputType(), + operatorId, + tableFunctionProcessorNode->id(), + "LeafTableFunctionOperator"), + driverCtx_(driverCtx), + pool_(pool()), + stringAllocator_(pool_), + tableFunctionProcessorNode_(tableFunctionProcessorNode), + result_(nullptr) { + VELOX_CHECK(tableFunctionProcessorNode->sources().empty()); + VELOX_CHECK(tableFunctionProcessorNode->partitionKeys().empty()); + VELOX_CHECK(tableFunctionProcessorNode->sortingKeys().empty()); + VELOX_CHECK(tableFunctionProcessorNode->sortingOrders().empty()); +} + +void LeafTableFunctionOperator::initialize() { + Operator::initialize(); + VELOX_CHECK_NOT_NULL(tableFunctionProcessorNode_); + // TODO: Why was this needed + // tableFunctionProcessorNode_.reset(); +} + +void LeafTableFunctionOperator::createTableFunctionSplitProcessor() { + splitProcessor_ = TableFunction::createSplitProcessor( + tableFunctionProcessorNode_->functionName(), + tableFunctionProcessorNode_->handle(), + pool_, + &stringAllocator_, + operatorCtx_->driverCtx()->queryConfig()); + VELOX_CHECK(splitProcessor_); +} + +RowVectorPtr LeafTableFunctionOperator::getOutput() { + if (noMoreSplits_) { + return nullptr; + } + + if (currentSplit_ == nullptr) { + // Try to retrieve the next split. If no more splits then return. + exec::Split split; + blockingReason_ = driverCtx_->task->getSplitOrFuture( + driverCtx_->splitGroupId, + planNodeId(), + split, + blockingFuture_, + 0, + splitPreloader_); + + if (blockingReason_ != BlockingReason::kNotBlocked) { + return nullptr; + } + + if (!split.hasConnectorSplit()) { + noMoreSplits_ = true; + return nullptr; + } + + currentSplit_ = + std::dynamic_pointer_cast(split.connectorSplit); + VELOX_CHECK(currentSplit_, "Invalid Table Function Split"); + + createTableFunctionSplitProcessor(); + } + + // This split could be one retrieved above or a incompletely processed one + // from the previous getOutput. + VELOX_CHECK_NOT_NULL(currentSplit_, "No split to process."); + + // GetOutput from table function. + VELOX_CHECK(splitProcessor_); + auto result = splitProcessor_->apply(currentSplit_->splitHandle()); + if (result->state() == TableFunctionResult::TableFunctionState::kFinished) { + // Clear the split as the input rows are completely consumed. + currentSplit_ = nullptr; + splitProcessor_ = nullptr; + return nullptr; + } + + VELOX_CHECK( + result->state() == TableFunctionResult::TableFunctionState::kProcessed); + // TODO: Figure what usedInput means for apply with splits. + // VELOX_CHECK(!result->usedInput()); + + auto resultRows = result->result(); + VELOX_CHECK(resultRows); + + return std::move(resultRows); +} + +void LeafTableFunctionOperator::reclaim( + uint64_t /*targetBytes*/, + memory::MemoryReclaimer::Stats& stats) { + VELOX_NYI("LeafTableFunctionOperator::reclaim not implemented"); +} + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/exec/LeafTableFunctionOperator.h b/presto-native-execution/presto_cpp/main/tvf/exec/LeafTableFunctionOperator.h new file mode 100644 index 0000000000000..b2aa192f904e8 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/exec/LeafTableFunctionOperator.h @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/tvf/core/TableFunctionProcessorNode.h" +#include "presto_cpp/main/tvf/exec/TableFunctionSplit.h" + +#include "velox/common/memory/HashStringAllocator.h" +#include "velox/exec/Operator.h" +#include "velox/exec/RowContainer.h" +#include "velox/vector/DecodedVector.h" + +namespace facebook::presto::tvf { + +class LeafTableFunctionOperator : public velox::exec::SourceOperator { + public: + LeafTableFunctionOperator( + int32_t operatorId, + velox::exec::DriverCtx* driverCtx, + const std::shared_ptr& + tableFunctionProcessorNode); + + void initialize() override; + + velox::RowVectorPtr getOutput() override; + + velox::exec::BlockingReason isBlocked( + velox::ContinueFuture* /* unused */) override { + return velox::exec::BlockingReason::kNotBlocked; +} + +bool isFinished() override { + return noMoreSplits_; +} + +void reclaim(uint64_t targetBytes, velox::memory::MemoryReclaimer::Stats& stats) + override; + +private: +bool spillEnabled() const { + return spillConfig_.has_value(); +} + +void createTableFunctionSplitProcessor(); + +void clear(); + +velox::exec::DriverCtx* const driverCtx_; + +velox::memory::MemoryPool* pool_; +// HashStringAllocator required by functions that allocate out of line +// buffers. +velox::HashStringAllocator stringAllocator_; + +std::shared_ptr tableFunctionProcessorNode_; + +std::shared_ptr result_; + +// This should be constructed for each split. +std::unique_ptr splitProcessor_; + +bool noMoreSplits_ = false; +std::shared_ptr currentSplit_; + +velox::ContinueFuture blockingFuture_{velox::ContinueFuture::makeEmpty()}; +velox::exec::BlockingReason blockingReason_{ + velox::exec::BlockingReason::kNotBlocked}; +std::function&)> + splitPreloader_{nullptr}; +} +; + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionOperator.cpp b/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionOperator.cpp new file mode 100644 index 0000000000000..6449e027b4d0b --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionOperator.cpp @@ -0,0 +1,198 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/tvf/exec/TableFunctionOperator.h" + +#include "presto_cpp/main/tvf/exec/TableFunctionPartition.h" + +#include "velox/common/memory/MemoryArbitrator.h" +#include "velox/vector/ComplexVector.h" + +namespace facebook::presto::tvf { + +using namespace facebook::velox; +using namespace facebook::velox::exec; + +namespace { + +const RowTypePtr requiredColumnType( + const std::string& name, + const TableFunctionProcessorNodePtr& tableFunctionProcessorNode) { + auto columns = tableFunctionProcessorNode->requiredColumns(); + + // TODO: This assumes single source. + auto inputType = tableFunctionProcessorNode->sources()[0]->outputType(); + std::vector names; + std::vector types; + for (const auto& idx : columns) { + names.push_back(inputType->nameOf(idx)); + types.push_back(inputType->childAt(idx)); + } + return ROW(std::move(names), std::move(types)); +} +} // namespace + +TableFunctionOperator::TableFunctionOperator( + int32_t operatorId, + DriverCtx* driverCtx, + const TableFunctionProcessorNodePtr& tableFunctionProcessorNode) + : Operator( + driverCtx, + tableFunctionProcessorNode->outputType(), + operatorId, + tableFunctionProcessorNode->id(), + "TableFunctionOperator", + tableFunctionProcessorNode->canSpill(driverCtx->queryConfig()) + ? driverCtx->makeSpillConfig(operatorId) + : std::nullopt), + pool_(pool()), + stringAllocator_(pool_), + tableFunctionProcessorNode_(tableFunctionProcessorNode), + inputType_(tableFunctionProcessorNode->sources()[0]->outputType()), + requiredColummType_(requiredColumnType("t1", tableFunctionProcessorNode)), + tableFunctionPartition_(nullptr), + input_(nullptr) { + tablePartitionBuild_ = std::make_unique( + inputType_, + tableFunctionProcessorNode->partitionKeys(), + tableFunctionProcessorNode->sortingKeys(), + tableFunctionProcessorNode->sortingOrders(), + pool(), + common::PrefixSortConfig{ + driverCtx->queryConfig().prefixSortNormalizedKeyMaxBytes(), + driverCtx->queryConfig().prefixSortMinRows(), + driverCtx->queryConfig().prefixSortMaxStringPrefixLength()}); + numRowsPerOutput_ = outputBatchRows(tablePartitionBuild_->estimateRowSize()); +} + +void TableFunctionOperator::initialize() { + Operator::initialize(); + VELOX_CHECK_NOT_NULL(tableFunctionProcessorNode_); +} + +void TableFunctionOperator::createTableFunctionDataProcessor( + const std::shared_ptr& node) { + dataProcessor_ = TableFunction::createDataProcessor( + node->functionName(), + node->handle(), + pool_, + &stringAllocator_, + operatorCtx_->driverCtx()->queryConfig()); + VELOX_CHECK(dataProcessor_); +} + +// Writing the code to add the input rows -> call TableFunction::process and +// return the rows from it. This is done per input vectors basis. If we have +// partition by an order by this would need a change but just testing with a +// simple model for now. +void TableFunctionOperator::addInput(RowVectorPtr input) { + numRows_ += input->size(); + + tablePartitionBuild_->addInput(input); +} + +void TableFunctionOperator::noMoreInput() { + Operator::noMoreInput(); + tablePartitionBuild_->noMoreInput(); +} + +void TableFunctionOperator::assembleInput() { + VELOX_CHECK(tableFunctionPartition_); + + const auto numRowsLeft = + tableFunctionPartition_->numRows() - numPartitionProcessedRows_; + VELOX_CHECK_GT(numRowsLeft, 0); + const auto numOutputRows = std::min(numRowsPerOutput_, numRowsLeft); + auto input = + BaseVector::create(requiredColummType_, numOutputRows, pool_); + + auto columns = tableFunctionProcessorNode_->requiredColumns(); + for (int i = 0; i < requiredColummType_->children().size(); i++) { + input->childAt(i)->resize(numOutputRows); + tableFunctionPartition_->extractColumn( + columns[i], + numPartitionProcessedRows_, + numOutputRows, + 0, + input->childAt(i)); + } + input_ = std::move(input); +} + +RowVectorPtr TableFunctionOperator::getOutput() { + if (!noMoreInput_) { + return nullptr; + } + + if (numRows_ == 0) { + return nullptr; + } + + const auto numRowsLeft = numRows_ - numProcessedRows_; + if (numRowsLeft == 0) { + return nullptr; + } + + if (tableFunctionPartition_ == nullptr || + (!input_ && + (tableFunctionPartition_->numRows() - numPartitionProcessedRows_ == + 0))) { + if (tablePartitionBuild_->hasNextPartition()) { + tableFunctionPartition_ = tablePartitionBuild_->nextPartition(); + createTableFunctionDataProcessor(tableFunctionProcessorNode_); + numPartitionProcessedRows_ = 0; + } else { + // There is no partition to output. + return nullptr; + } + } + + // This is the first call to TableFunction::apply for this partition + // or a previous apply for this input has completed. + if (input_ == nullptr) { + assembleInput(); + } + + VELOX_CHECK(dataProcessor_); + auto result = dataProcessor_->apply({input_}); + if (result->state() == TableFunctionResult::TableFunctionState::kFinished) { + // Skip the rest of this partition processing. + numProcessedRows_ += + (tableFunctionPartition_->numRows() - numPartitionProcessedRows_); + tableFunctionPartition_ = nullptr; + numPartitionProcessedRows_ = 0; + return nullptr; + } + + VELOX_CHECK( + result->state() == TableFunctionResult::TableFunctionState::kProcessed); + auto resultRows = result->result(); + VELOX_CHECK(resultRows); + if (result->usedInput()) { + // The input rows were consumed, so we need to re-assemble input at the + // next call. + numPartitionProcessedRows_ += input_->size(); + numProcessedRows_ += input_->size(); + input_ = nullptr; + } + return std::move(resultRows); +} + +void TableFunctionOperator::reclaim( + uint64_t /*targetBytes*/, + memory::MemoryReclaimer::Stats& stats) { + VELOX_NYI("TableFunctionOperator::reclaim not implemented"); +} + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionOperator.h b/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionOperator.h new file mode 100644 index 0000000000000..bb5878fc7e7a8 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionOperator.h @@ -0,0 +1,103 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/tvf/core/TableFunctionProcessorNode.h" +#include "presto_cpp/main/tvf/exec/TableFunctionPartition.h" +#include "presto_cpp/main/tvf/exec/TablePartitionBuild.h" + +#include "velox/common/memory/HashStringAllocator.h" +#include "velox/exec/Operator.h" +#include "velox/exec/RowContainer.h" +#include "velox/vector/DecodedVector.h" + +namespace facebook::presto::tvf { + +class TableFunctionOperator : public velox::exec::Operator { + public: + TableFunctionOperator( + int32_t operatorId, + velox::exec::DriverCtx* driverCtx, + const std::shared_ptr& + tableFunctionProcessorNode); + + void initialize() override; + + void addInput(velox::RowVectorPtr input) override; + + void noMoreInput() override; + + velox::RowVectorPtr getOutput() override; + + bool needsInput() const override { + return !noMoreInput_; + } + + velox::exec::BlockingReason isBlocked( + velox::ContinueFuture* /* unused */) override { + return velox::exec::BlockingReason::kNotBlocked; + } + + bool isFinished() override { + // There is no input and the function has completed as well. + return (noMoreInput_ && input_ == nullptr); + } + + void reclaim(uint64_t targetBytes, velox::memory::MemoryReclaimer::Stats& stats) + override; + + private: + bool spillEnabled() const { + return spillConfig_.has_value(); + } + + void createTableFunctionDataProcessor( + const std::shared_ptr& + tableFunctionProcessorNode); + + void assembleInput(); + + velox::memory::MemoryPool* pool_; + // HashStringAllocator required by functions that allocate out of line + // buffers. + velox::HashStringAllocator stringAllocator_; + + std::shared_ptr tableFunctionProcessorNode_; + + // TODO : Figure how this works for a multi-input table parameter case. + velox::RowTypePtr inputType_; + + // This would be a list when the operator supports multiple TableArguments. + const velox::RowTypePtr requiredColummType_; + + // TablePartitionBuild is used to store input rows and return + // TableFunctionPartitions for the processing. + std::unique_ptr tablePartitionBuild_; + + std::shared_ptr tableFunctionPartition_; + + velox::RowVectorPtr input_; + + // This should be constructed for each partition. + std::unique_ptr dataProcessor_; + + velox::vector_size_t numRows_ = 0; + velox::vector_size_t numProcessedRows_ = 0; + velox::vector_size_t numPartitionProcessedRows_ = 0; + // Number of rows that be fit into an output block. + velox::vector_size_t numRowsPerOutput_; +}; + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionPartition.cpp b/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionPartition.cpp new file mode 100644 index 0000000000000..b85d6083a3a56 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionPartition.cpp @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/tvf/exec/TableFunctionPartition.h" + +namespace facebook::presto::tvf { + +using namespace facebook::velox; +using namespace facebook::velox::exec; + +TableFunctionPartition::TableFunctionPartition( + RowContainer* data, + const folly::Range& rows, + const std::vector& inputMapping) + : data_(data), partition_(rows), inputMapping_(inputMapping) {} + +TableFunctionPartition::~TableFunctionPartition() { + partition_.clear(); +} + +void TableFunctionPartition::extractColumn( + int32_t columnIndex, + folly::Range rowNumbers, + vector_size_t resultOffset, + const VectorPtr& result) const { + RowContainer::extractColumn( + partition_.data(), + rowNumbers, + data_->columnAt(inputMapping_[columnIndex]), + data_->columnHasNulls(inputMapping_[columnIndex]), + resultOffset, + result); +} + +void TableFunctionPartition::extractColumn( + int32_t columnIndex, + vector_size_t partitionOffset, + vector_size_t numRows, + vector_size_t resultOffset, + const VectorPtr& result) const { + RowContainer::extractColumn( + partition_.data() + partitionOffset, + numRows, + data_->columnAt(inputMapping_[columnIndex]), + data_->columnHasNulls(inputMapping_[columnIndex]), + resultOffset, + result); +} + +void TableFunctionPartition::extractNulls( + int32_t columnIndex, + vector_size_t partitionOffset, + vector_size_t numRows, + const BufferPtr& nullsBuffer) const { + RowContainer::extractNulls( + partition_.data() + partitionOffset, + numRows, + data_->columnAt(inputMapping_[columnIndex]), + nullsBuffer); +} + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionPartition.h b/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionPartition.h new file mode 100644 index 0000000000000..b75172941d811 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionPartition.h @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/exec/RowContainer.h" +#include "velox/vector/BaseVector.h" + +/// Simple TableFunctionPartition that builds over the RowContainer used for +/// storing the input rows in the Table Function Operator. This works completely +/// in-memory. + +namespace facebook::presto::tvf { + +class TableFunctionPartition { + public: + /// The TableFunctionPartition is used by the TableFunctionOperator and + /// TableFunction objects to access the underlying data and columns of a + /// partition of rows. + /// The TableFunctionPartition is constructed by TableFunctionOperator + /// from the input data. + /// 'data' : Underlying RowContainer of the TableFunctionOperator. + TableFunctionPartition( + velox::exec::RowContainer* data, + const folly::Range& rows, + const std::vector& inputMapping); + + ~TableFunctionPartition(); + + /// Returns the number of rows in the current TableFunctionPartition. + velox::vector_size_t numRows() const { + return partition_.size(); + } + + /// Copies the values at 'columnIndex' into 'result' (starting at + /// 'resultOffset') for the rows at positions in the 'rowNumbers' + /// array from the partition input data. + void extractColumn( + int32_t columnIndex, + folly::Range rowNumbers, + velox::vector_size_t resultOffset, + const velox::VectorPtr& result) const; + + /// Copies the values at 'columnIndex' into 'result' (starting at + /// 'resultOffset') for 'numRows' starting at positions 'partitionOffset' + /// in the partition input data. + void extractColumn( + int32_t columnIndex, + velox::vector_size_t partitionOffset, + velox::vector_size_t numRows, + velox::vector_size_t resultOffset, + const velox::VectorPtr& result) const; + + /// Extracts null positions at 'columnIndex' into 'nullsBuffer' for + /// 'numRows' starting at positions 'partitionOffset' in the partition + /// input data. + void extractNulls( + int32_t columnIndex, + velox::vector_size_t partitionOffset, + velox::vector_size_t numRows, + const velox::BufferPtr& nullsBuffer) const; + + private: + // The RowContainer associated with the partition. + // It is owned by the TablePartitionBuild that creates the partition. + velox::exec::RowContainer* const data_; + + // folly::Range is for the partition rows iterator provided by the + // TableFunctionOperator. The pointers are to rows from a RowContainer owned + // by the operator. We can assume these are valid values for the lifetime + // of TableFunctionPartition. + folly::Range partition_; + + // Mapping from window input column -> index in data_. This is required + // because the TableFunctionPartitionBuild reorders data_ to place partition + // and sort keys before other columns in data_. But the TableFunctionOperator + // and TableFunction code accesses TableFunctionPartition using the + // indexes of TableFunction input type. + const std::vector inputMapping_; +}; +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionSplit.h b/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionSplit.h new file mode 100644 index 0000000000000..28bcfcb7bf4d9 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionSplit.h @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "presto_cpp/main/tvf/spi/TableFunction.h" +#include "velox/connectors/Connector.h" + +namespace facebook::presto::tvf { + +struct TableFunctionSplit : public velox::connector::ConnectorSplit { + explicit TableFunctionSplit(const TableSplitHandlePtr& handle) + : ConnectorSplit(""), splitHandle_(handle) {} + + const TableSplitHandlePtr splitHandle() { + return splitHandle_; + } + + private: + const TableSplitHandlePtr splitHandle_; +}; + +} // namespace facebook::presto::tvf + +template <> +struct fmt::formatter + : formatter { + auto format( + facebook::presto::tvf::TableFunctionSplit s, + format_context& ctx) { + return formatter::format(s.toString(), ctx); + } +}; + +template <> +struct fmt::formatter< + std::shared_ptr> + : formatter { + auto format( + std::shared_ptr s, + format_context& ctx) const { + return formatter::format(s->toString(), ctx); + } +}; diff --git a/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionTranslator.h b/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionTranslator.h new file mode 100644 index 0000000000000..46d288dd967cb --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/exec/TableFunctionTranslator.h @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/tvf/core/TableFunctionProcessorNode.h" +#include "presto_cpp/main/tvf/exec/LeafTableFunctionOperator.h" +#include "presto_cpp/main/tvf/exec/TableFunctionOperator.h" + +#include "velox/exec/Operator.h" + +namespace facebook::presto::tvf { + +// Custom translation logic to hook into Velox Driver. +class TableFunctionTranslator + : public velox::exec::Operator::PlanNodeTranslator { + std::unique_ptr toOperator( + velox::exec::DriverCtx* ctx, + int32_t id, + const velox::core::PlanNodePtr& node) { + if (auto tableFunctionProcessorNode = + std::dynamic_pointer_cast(node)) { + if (tableFunctionProcessorNode->sources().empty()) { + return std::make_unique( + id, ctx, tableFunctionProcessorNode); + } + return std::make_unique( + id, ctx, tableFunctionProcessorNode); + } + return nullptr; + } +}; + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/exec/TablePartitionBuild.cpp b/presto-native-execution/presto_cpp/main/tvf/exec/TablePartitionBuild.cpp new file mode 100644 index 0000000000000..9b495badeee52 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/exec/TablePartitionBuild.cpp @@ -0,0 +1,278 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/tvf/exec/TablePartitionBuild.h" +#include + +#include "velox/exec/Operator.h" + +namespace facebook::presto::tvf { + +using namespace facebook::velox; + +namespace { +std::tuple, std::vector, RowTypePtr> +reorderInputChannels( + const RowTypePtr& inputType, + const std::vector& partitionKeys, + const std::vector& sortingKeys) { + const auto size = inputType->size(); + + std::vector channels; + std::vector inversedChannels; + std::vector names; + std::vector types; + channels.reserve(size); + inversedChannels.resize(size); + names.reserve(size); + types.reserve(size); + + std::unordered_set keyNames; + + auto appendChannel = + [&inputType, &channels, &inversedChannels, &names, &types]( + column_index_t channel) { + channels.push_back(channel); + inversedChannels[channel] = channels.size() - 1; + names.push_back(inputType->nameOf(channel)); + types.push_back(inputType->childAt(channel)); + }; + + for (const auto& key : partitionKeys) { + auto channel = exec::exprToChannel(key.get(), inputType); + appendChannel(channel); + keyNames.insert(key->name()); + } + + for (const auto& key : sortingKeys) { + auto channel = exec::exprToChannel(key.get(), inputType); + appendChannel(channel); + keyNames.insert(key->name()); + } + + for (auto i = 0; i < size; ++i) { + if (keyNames.count(inputType->nameOf(i)) == 0) { + appendChannel(i); + } + } + + return std::make_tuple( + channels, inversedChannels, ROW(std::move(names), std::move(types))); +} + +// Returns a [start, end) slice of the 'types' vector. +std::vector +slice(const std::vector& types, int32_t start, int32_t end) { + std::vector result; + result.reserve(end - start); + for (auto i = start; i < end; ++i) { + result.push_back(types[i]); + } + return result; +} + +std::vector makeCompareFlags( + int32_t numPartitionKeys, + const std::vector& sortingOrders) { + std::vector compareFlags; + compareFlags.reserve(numPartitionKeys + sortingOrders.size()); + + for (auto i = 0; i < numPartitionKeys; ++i) { + compareFlags.push_back({}); + } + + for (const auto& order : sortingOrders) { + compareFlags.push_back( + {order.isNullsFirst(), order.isAscending(), false /*equalsOnly*/}); + } + + return compareFlags; +} + +} // namespace + +TablePartitionBuild::TablePartitionBuild( + const RowTypePtr& inputType, + std::vector partitionKeys, + std::vector sortingKeys, + std::vector sortingOrders, + memory::MemoryPool* pool, + common::PrefixSortConfig&& prefixSortConfig) + : pool_(pool), + inputType_(inputType), + compareFlags_{makeCompareFlags(partitionKeys.size(), sortingOrders)}, + prefixSortConfig_(prefixSortConfig), + decodedInputVectors_(inputType->size()), + sortedRows_(0, memory::StlAllocator(*pool)), + partitionStartRows_(0, memory::StlAllocator(*pool)) { + VELOX_CHECK_NOT_NULL(pool_); + std::tie(inputChannels_, inversedInputChannels_, inputType_) = + reorderInputChannels(inputType, partitionKeys, sortingKeys); + + const auto numPartitionKeys = partitionKeys.size(); + const auto numSortingKeys = sortingKeys.size(); + const auto numKeys = numPartitionKeys + numSortingKeys; + + data_ = std::make_unique( + slice(inputType_->children(), 0, numKeys), + slice(inputType_->children(), numKeys, inputType_->size()), + pool); + + for (auto i = 0; i < numPartitionKeys; ++i) { + partitionKeyInfo_.push_back(std::make_pair(i, core::SortOrder{true, true})); + } + + for (auto i = 0; i < numSortingKeys; ++i) { + sortKeyInfo_.push_back( + std::make_pair(numPartitionKeys + i, sortingOrders[i])); + } +} + +void TablePartitionBuild::addInput(RowVectorPtr input) { + for (auto i = 0; i < inputChannels_.size(); ++i) { + decodedInputVectors_[i].decode(*input->childAt(inputChannels_[i])); + } + + // Add all the rows into the RowContainer. + for (auto row = 0; row < input->size(); ++row) { + char* newRow = data_->newRow(); + + for (auto col = 0; col < inputChannels_.size(); ++col) { + data_->store(decodedInputVectors_[col], row, newRow, col); + } + } + numRows_ += input->size(); +} + +void TablePartitionBuild::noMoreInput() { + if (numRows_ == 0) { + return; + } + + // At this point we have seen all the input rows. The operator is + // being prepared to output rows now. + // To prepare the rows for output in SortWindowBuild they need to + // be separated into partitions and sort by ORDER BY keys within + // the partition. This will order the rows for getOutput(). + sortPartitions(); +} + +void TablePartitionBuild::sortPartitions() { + // This is a very inefficient but easy implementation to order the input rows + // by partition keys + sort keys. + // Sort the pointers to the rows in RowContainer (data_) instead of sorting + // the rows. + sortedRows_.resize(numRows_); + exec::RowContainerIterator iter; + data_->listRows(&iter, numRows_, sortedRows_.data()); + + exec::PrefixSort::sort( + data_.get(), compareFlags_, prefixSortConfig_, pool_, sortedRows_); + + computePartitionStartRows(); +} + +void TablePartitionBuild::computePartitionStartRows() { + partitionStartRows_.reserve(numRows_); + + // Using a sequential traversal to find changing partitions. + // This algorithm is inefficient and can be changed + // i) Use a binary search kind of strategy. + // ii) If we use a Hashtable instead of a full sort then the count + // of rows in the partition can be directly used. + partitionStartRows_.push_back(0); + + VELOX_CHECK_GT(sortedRows_.size(), 0); + + vector_size_t start = 0; + while (start < sortedRows_.size()) { + auto next = findNextPartitionStartRow(start); + partitionStartRows_.push_back(next); + start = next; + } +} + +bool TablePartitionBuild::compareRowsWithKeys( + const char* lhs, + const char* rhs, + const std::vector>& keys) { + if (lhs == rhs) { + return false; + } + for (auto& key : keys) { + if (auto result = data_->compare( + lhs, + rhs, + key.first, + {key.second.isNullsFirst(), key.second.isAscending(), false})) { + return result < 0; + } + } + return false; +} + +// Use double front and back search algorithm to find next partition start row. +// It is more efficient than linear or binary search. +// This algorithm is described at +// https://medium.com/@insomniocode/search-algorithm-double-front-and-back-20f5f28512e7 +vector_size_t TablePartitionBuild::findNextPartitionStartRow( + vector_size_t start) { + auto partitionCompare = [&](const char* lhs, const char* rhs) -> bool { + return compareRowsWithKeys(lhs, rhs, partitionKeyInfo_); + }; + + auto left = start; + auto right = left + 1; + auto lastPosition = sortedRows_.size(); + while (right < lastPosition) { + auto distance = 1; + for (; distance < lastPosition - left; distance *= 2) { + right = left + distance; + if (partitionCompare(sortedRows_[left], sortedRows_[right]) != 0) { + lastPosition = right; + break; + } + } + left += distance / 2; + right = left + 1; + } + return right; +} + +std::shared_ptr TablePartitionBuild::nextPartition() { + VELOX_CHECK( + !partitionStartRows_.empty(), "No table function partitions available"); + + currentPartition_++; + VELOX_CHECK_LE( + currentPartition_, + partitionStartRows_.size() - 2, + "All table function partitions consumed"); + + // There is partition data available now. + auto partitionSize = partitionStartRows_[currentPartition_ + 1] - + partitionStartRows_[currentPartition_]; + auto partition = folly::Range( + sortedRows_.data() + partitionStartRows_[currentPartition_], + partitionSize); + return std::make_shared( + data_.get(), partition, inversedInputChannels_); +} + +bool TablePartitionBuild::hasNextPartition() { + return partitionStartRows_.size() > 0 && + currentPartition_ < static_cast(partitionStartRows_.size() - 2); +} + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/exec/TablePartitionBuild.h b/presto-native-execution/presto_cpp/main/tvf/exec/TablePartitionBuild.h new file mode 100644 index 0000000000000..d9f41cd306508 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/exec/TablePartitionBuild.h @@ -0,0 +1,154 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/tvf/exec/TableFunctionPartition.h" + +#include "velox/exec/PrefixSort.h" + +namespace facebook::presto::tvf { + +// This is a class used by the TableFunction Operator to maintain and provide +// table partitions for function evaluation at execution. + +class TablePartitionBuild { + public: + TablePartitionBuild( + const velox::RowTypePtr& inputType, + std::vector partitionKeys, + std::vector sortingKeys, + std::vector sortingOrders, + velox::memory::MemoryPool* pool, + velox::common::PrefixSortConfig&& prefixSortConfig); + + ~TablePartitionBuild() { + pool_->release(); + } + + void addInput(velox::RowVectorPtr input); + + void noMoreInput(); + + /// Returns true if a new Table function partition is available for the + /// Table function operator to consume. + bool hasNextPartition(); + + /// The Table function operator invokes this function to get the next + /// TableFunctionPartition for evaluation. The TableFunctionPartition + /// has APIs to access the underlying columns of Table Function Operator data. + /// Check hasNextPartition() before invoking this function. This function + /// fails if called when no partition is available. + std::shared_ptr nextPartition(); + + velox::vector_size_t numRows() { + return numRows_; + } + + /// Returns the average size of input rows in bytes stored in the data + /// container of the WindowBuild. + std::optional estimateRowSize() { + return data_->estimateRowSize(); + } + + private: + // Main sorting function loop done after all input rows are received + // by WindowBuild. + void sortPartitions(); + + // Function to compute the partitionStartRows_ structure. + // partitionStartRows_ is vector of the starting rows index + // of each partition in the data. This is an auxiliary + // structure that helps simplify the window function computations. + void computePartitionStartRows(); + + // Find the next partition start row from start. + velox::vector_size_t findNextPartitionStartRow(velox::vector_size_t start); + + bool compareRowsWithKeys( + const char* lhs, + const char* rhs, + const std::vector< + std::pair>& keys); + + velox::memory::MemoryPool* const pool_; + + /// Input column types in 'inputChannels_' order. + velox::RowTypePtr inputType_; + + // Compare flags for partition and sorting keys. Compare flags for partition + // keys are set to default values. Compare flags for sorting keys match + // sorting order specified in the plan node. + // + // Used to sort 'data_' while spilling and in Prefix sort. + const std::vector compareFlags_; + + // Config for Prefix-sort. + const velox::common::PrefixSortConfig prefixSortConfig_; + + /// Input columns in the order of: partition keys, sorting keys, the rest. + std::vector inputChannels_; + + /// The mapping from original input column index to the index after column + /// reordering. This is the inversed mapping of inputChannels_. + std::vector inversedInputChannels_; + + /// The RowContainer holds all the input rows in TablePartitionBuild. Columns + /// are already reordered according to inputChannels_. + std::unique_ptr data_; + + /// The decodedInputVectors_ are reused across addInput() calls to decode the + /// partition and sort keys for the above RowContainer. + std::vector decodedInputVectors_; + + /// Number of input rows. + velox::vector_size_t numRows_ = 0; + + /// The below 2 vectors represent the ChannelIndex of the partition keys + /// and the order by keys. These keyInfo are used for sorting by those + /// key combinations during the processing. partitionKeyInfo_ is used to + /// separate partitions in the rows. sortKeyInfo_ is used to identify peer + /// rows in a partition. + std::vector> + partitionKeyInfo_; + std::vector> + sortKeyInfo_; + + // allKeyInfo_ is a combination of (partitionKeyInfo_ and sortKeyInfo_). + // It is used to perform a full sorting of the input rows to be able to + // separate partitions and sort the rows in it. The rows are output in + // this order by the operator. + std::vector> + allKeyInfo_; + + // Vector of pointers to each input row in the data_ RowContainer. + // The rows are sorted by partitionKeys + sortKeys. This total + // ordering can be used to split partitions (with the correct + // order by) for the processing. + std::vector> sortedRows_; + + // This is a vector that gives the index of the start row + // (in sortedRows_) of each partition in the RowContainer data_. + // This auxiliary structure helps demarcate partitions. + std::vector< + velox::vector_size_t, + velox::memory::StlAllocator> + partitionStartRows_; + + // Current partition being output. Used to construct WindowPartitions + // during resetPartition. + velox::vector_size_t currentPartition_ = -1; +}; + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/functions/CMakeLists.txt b/presto-native-execution/presto_cpp/main/tvf/functions/CMakeLists.txt new file mode 100644 index 0000000000000..5c92dd326deeb --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/functions/CMakeLists.txt @@ -0,0 +1,21 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_library( + presto_tvf_functions + ExcludeColumns.cpp + RemoteAnn.cpp + Sequence.cpp + TableFunctionsRegistration.cpp + TestingTableFunctions.cpp +) + +target_link_libraries(presto_tvf_functions presto_tvf_spi) diff --git a/presto-native-execution/presto_cpp/main/tvf/functions/ExcludeColumns.cpp b/presto-native-execution/presto_cpp/main/tvf/functions/ExcludeColumns.cpp new file mode 100644 index 0000000000000..89c78d8d73752 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/functions/ExcludeColumns.cpp @@ -0,0 +1,171 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/tvf/spi/TableFunction.h" +#include "velox/common/base/Exceptions.h" + +namespace facebook::presto::tvf { + +using namespace facebook::velox; + +namespace { + +class ExcludeColumnsHandle : public TableFunctionHandle { + public: + ExcludeColumnsHandle(){}; + + std::string_view name() const override { + return "ExcludeColumnsHandle"; + } + + folly::dynamic serialize() const override { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = fmt::format("{}", name()); + return obj; + } + + static std::shared_ptr create( + const folly::dynamic& obj, + void* context) { + return std::make_shared(); + } + + static void registerSerDe() { + auto& registry = velox::DeserializationWithContextRegistryForSharedPtr(); + registry.Register("ExcludeColumnsHandle", ExcludeColumnsHandle::create); + } + + private: +}; + +class ExcludeColumnsAnalysis : public TableFunctionAnalysis { + public: + explicit ExcludeColumnsAnalysis() : TableFunctionAnalysis() {} +}; + +static const std::string TABLE_ARGUMENT_NAME = "INPUT"; +static const std::string DESCRIPTOR_ARGUMENT_NAME = "COLUMNS"; + +class ExcludeColumnsDataProcessor : public TableFunctionDataProcessor { + public: + explicit ExcludeColumnsDataProcessor( + const ExcludeColumnsHandle* handle, + memory::MemoryPool* pool) + : TableFunctionDataProcessor("exclude_columns", pool, nullptr), + handle_(handle) {} + + std::shared_ptr apply( + const std::vector& input) override { + auto inputTable = input.at(0); + auto numRows = inputTable->size(); + if (numRows == 0) { + return std::make_shared( + TableFunctionResult::TableFunctionState::kFinished); + } + + // Get a projection of non-excluded columns from inputTable. + return std::make_shared(true, std::move(inputTable)); + } + + private: + const ExcludeColumnsHandle* handle_; +}; + +class ExcludeColumns : public TableFunction { + public: + static std::unique_ptr analyze( + const std::unordered_map>& args) { + VELOX_CHECK_GT( + args.count(DESCRIPTOR_ARGUMENT_NAME), 0, "COLUMNS arg not found"); + VELOX_CHECK_GT(args.count(TABLE_ARGUMENT_NAME), 0, "INPUT arg not found"); + + auto excludedColumnsArg = args.at(DESCRIPTOR_ARGUMENT_NAME); + VELOX_CHECK(excludedColumnsArg, "COLUMNS arg is NULL"); + auto excludedColumnsDesc = + std::dynamic_pointer_cast(excludedColumnsArg); + VELOX_CHECK(excludedColumnsDesc, "COLUMNS arg not a descriptor"); + + auto inputArg = args.at(TABLE_ARGUMENT_NAME); + VELOX_CHECK(inputArg, "INPUT arg is NULL"); + auto inputTableArg = std::dynamic_pointer_cast(inputArg); + VELOX_CHECK(inputTableArg, "INPUT arg not a table"); + + // Validate that each excluded column is found in the input table + // and remove it. + auto inputColumns = inputTableArg->rowType()->names(); + std::unordered_set inputColumnsSet; + for (const auto& col : inputColumns) { + inputColumnsSet.insert(col); + } + for (const auto& col : excludedColumnsDesc->names()) { + VELOX_CHECK_GT( + inputColumnsSet.count(col), 0, "COLUMN {} not found in INPUT", col); + } + std::unordered_set excludeColumnsSet; + for (const auto& col : excludedColumnsDesc->names()) { + excludeColumnsSet.insert(col); + } + + std::vector returnNames; + std::vector returnTypes; + std::unordered_map> + requiredColumns; + requiredColumns.reserve(1); + std::vector requiredColsList; + for (column_index_t i = 0; i < inputColumns.size(); i++) { + if (excludeColumnsSet.count(inputColumns.at(i)) == 0) { + // This column is not in the exclude_columns list and so is returned in + // the output. + returnNames.push_back(inputColumns.at(i)); + returnTypes.push_back(inputTableArg->rowType()->childAt(i)); + requiredColsList.push_back(i); + } + } + requiredColumns.insert({TABLE_ARGUMENT_NAME, requiredColsList}); + auto analysis = std::make_unique(); + analysis->tableFunctionHandle_ = std::make_shared(); + analysis->returnType_ = + std::make_shared(returnNames, returnTypes); + analysis->requiredColumns_ = requiredColumns; + return std::move(analysis); + } + + velox::RowTypePtr returnType_; + const SelectivityVector inputSelections_; +}; +} // namespace + +void registerExcludeColumns(const std::string& name) { + TableArgumentSpecList argSpecs; + argSpecs.insert( + std::make_shared(TABLE_ARGUMENT_NAME, true, true, false)); + argSpecs.insert(std::make_shared( + DESCRIPTOR_ARGUMENT_NAME, Descriptor({"columns"}), true)); + registerTableFunction( + name, + argSpecs, + std::make_shared(), + ExcludeColumns::analyze, + [](const TableFunctionHandlePtr& handle, + memory::MemoryPool* pool, + HashStringAllocator* /*stringAllocator*/, + const core::QueryConfig& /*queryConfig*/) + -> std::unique_ptr { + auto excludeHandle = dynamic_cast(handle.get()); + return std::make_unique(excludeHandle, pool); + }); + ExcludeColumnsHandle::registerSerDe(); +} + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/functions/RemoteAnn.cpp b/presto-native-execution/presto_cpp/main/tvf/functions/RemoteAnn.cpp new file mode 100644 index 0000000000000..9313293f930f5 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/functions/RemoteAnn.cpp @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/tvf/spi/TableFunction.h" +#include "velox/common/base/Exceptions.h" + +namespace facebook::presto::tvf { + +namespace { + +class RemoteAnnAnalysis : public TableFunctionAnalysis { + public: + explicit RemoteAnnAnalysis() : TableFunctionAnalysis() {} +}; + +class RemoteAnnDataProcessor : public TableFunctionDataProcessor { + public: + explicit RemoteAnnDataProcessor(velox::memory::MemoryPool* pool) + : TableFunctionDataProcessor("remote_ann", pool, nullptr) {} + + std::shared_ptr apply( + const std::vector& input) override { + return std::make_shared( + TableFunctionResult::TableFunctionState::kFinished); + } +}; +} // namespace + +void registerRemoteAnn(const std::string& name) { + registerTableFunction( + name, + {}, + std::make_shared(), + [](const std::unordered_map>& args) + -> std::unique_ptr { + return std::make_unique(); + }, + [](const std::shared_ptr& handle, + velox::memory::MemoryPool* pool, + velox::HashStringAllocator* /*stringAllocator*/, + const velox::core::QueryConfig& /*queryConfig*/) + -> std::unique_ptr { + return std::make_unique(pool); +}); +} + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/functions/Sequence.cpp b/presto-native-execution/presto_cpp/main/tvf/functions/Sequence.cpp new file mode 100644 index 0000000000000..0e4511c0d596b --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/functions/Sequence.cpp @@ -0,0 +1,259 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/tvf/spi/TableFunction.h" +#include "velox/common/base/Exceptions.h" +#include "velox/vector/FlatVector.h" + +using namespace facebook::velox; + +namespace facebook::presto::tvf { + +namespace { + +static const std::string START_ARGUMENT_NAME = "START"; +static const std::string STOP_ARGUMENT_NAME = "STOP"; +static const std::string STEP_ARGUMENT_NAME = "STEP"; + +class SequenceHandle : public TableFunctionHandle { + public: + SequenceHandle(int64_t start, int64_t stop, int64_t step) + : start_(start), stop_(stop), step_(step){}; + + std::string_view name() const override { + return "SequenceHandle"; + } + + folly::dynamic serialize() const override { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = fmt::format("{}", name()); + obj["start"] = start_; + obj["stop"] = stop_; + obj["step"] = step_; + return obj; + } + + static std::shared_ptr create( + const folly::dynamic& obj, + void* context) { + return std::make_shared( + obj["start"].asInt(), obj["stop"].asInt(), obj["step"].asInt()); + } + + static void registerSerDe() { + auto& registry = velox::DeserializationWithContextRegistryForSharedPtr(); + registry.Register("SequenceHandle", SequenceHandle::create); + } + + int64_t start() const { + return start_; + } + + int64_t stop() const { + return stop_; + } + + int64_t step() const { + return step_; + } + + private: + int64_t start_; + int64_t stop_; + int64_t step_; +}; + +class SequenceSplitHandle : public TableSplitHandle { + public: + SequenceSplitHandle(int64_t start, int64_t numSteps) + : start_(start), numSteps_(numSteps){}; + + std::string_view name() const override { + return "SequenceSplitHandle"; + } + + int64_t start() const { + return start_; + } + + int64_t numSteps() const { + return numSteps_; + } + + folly::dynamic serialize() const override { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = fmt::format("{}", name()); + obj["start"] = start_; + obj["numSteps"] = numSteps_; + return obj; + } + + static std::shared_ptr create( + const folly::dynamic& obj, + void* context) { + return std::make_shared( + obj["start"].asInt(), obj["numSteps"].asInt()); + } + + static void registerSerDe() { + auto& registry = velox::DeserializationWithContextRegistryForSharedPtr(); + registry.Register("SequenceSplitHandle", SequenceSplitHandle::create); + } + + private: + int64_t start_; + int64_t numSteps_; +}; + +class SequenceAnalysis : public TableFunctionAnalysis { + public: + explicit SequenceAnalysis() : TableFunctionAnalysis() {} +}; + +class Sequence : public TableFunctionSplitProcessor { + public: + explicit Sequence( + velox::memory::MemoryPool* pool, + const SequenceHandle* handle) + : TableFunctionSplitProcessor("sequence", pool, nullptr), + step_(handle->step()) {} + + static std::unique_ptr analyze( + const std::unordered_map>& args) { + VELOX_CHECK_GT(args.count(START_ARGUMENT_NAME), 0, "START arg not found"); + VELOX_CHECK_GT(args.count(STOP_ARGUMENT_NAME), 0, "STOP arg not found"); + + auto startArg = args.at(START_ARGUMENT_NAME); + VELOX_CHECK(startArg, "START arg is NULL"); + auto startPtr = std::dynamic_pointer_cast(startArg); + VELOX_CHECK(startPtr, "START arg is not a scalar"); + auto startVal = + startPtr->value()->template as>()->valueAt(0); + + auto stopArg = args.at(STOP_ARGUMENT_NAME); + VELOX_CHECK(stopArg, "STOP arg is NULL"); + auto stopPtr = std::dynamic_pointer_cast(stopArg); + VELOX_CHECK(stopPtr, "STOP arg is not a scalar"); + auto stopVal = + stopPtr->value()->template as>()->valueAt(0); + + auto stepArg = args.at(STEP_ARGUMENT_NAME); + VELOX_CHECK(stepArg, "STEP arg is NULL"); + auto stepPtr = std::dynamic_pointer_cast(stepArg); + VELOX_CHECK(stepPtr, "STEP arg is not a scalar"); + auto stepVal = + stepPtr->value()->template as>()->valueAt(0); + + auto handle = std::make_shared(startVal, stopVal, stepVal); + auto analysis = std::make_unique(); + analysis->tableFunctionHandle_ = handle; + return analysis; + } + + std::shared_ptr apply( + const std::shared_ptr& split) override { + auto sequenceSplit = + std::dynamic_pointer_cast(split); + VELOX_CHECK(sequenceSplit, "Split was not a SequenceSplitHandle"); + + if (processed_) { + processed_ = false; + return std::make_shared( + TableFunctionResult::TableFunctionState::kFinished); + } + + VELOX_CHECK(!processed_); + + auto start = sequenceSplit->start(); + auto numSteps = sequenceSplit->numSteps(); + auto sequenceCol = + BaseVector::create>(BIGINT(), numSteps, pool_); + auto rawValues = sequenceCol->values()->asMutable(); + for (auto i = 0; i < numSteps; i++) { + rawValues[i] = start + i * step_; + } + + auto result = + BaseVector::create(ROW({BIGINT()}), numSteps, pool_); + result->childAt(0) = sequenceCol; + + processed_ = true; + return std::make_shared(true, result); + } + + static std::vector getSplits( + const TableFunctionHandlePtr& handle) { + static const int64_t kMaxSteps = 1000000; + auto sequenceHandle = + std::dynamic_pointer_cast(handle); + auto start = sequenceHandle->start(); + auto stop = sequenceHandle->stop(); + auto step = sequenceHandle->step(); + + auto numSteps = (stop - start) / step + 1; + + std::vector splits = {}; + splits.reserve((numSteps / kMaxSteps) + 1); + auto splitStart = start; + while (numSteps > 0) { + auto splitSteps = numSteps < kMaxSteps ? numSteps : kMaxSteps; + auto sequenceSplit = + std::make_shared(splitStart, splitSteps); + splits.push_back(sequenceSplit); + numSteps -= kMaxSteps; + splitStart = start + (kMaxSteps * step); + } + return splits; + } + + private: + int64_t step_; + bool processed_; +}; +} // namespace + +void registerSequence(const std::string& name) { + TableArgumentSpecList argSpecs; + argSpecs.insert(std::make_shared( + START_ARGUMENT_NAME, BIGINT(), true)); + argSpecs.insert(std::make_shared( + STOP_ARGUMENT_NAME, BIGINT(), true)); + // TODO : Figure how to make this an optional argument. + argSpecs.insert(std::make_shared( + STEP_ARGUMENT_NAME, BIGINT(), true)); + + std::vector names = {"sequential_number"}; + std::vector types = {BIGINT()}; + auto returnType = std::make_shared(names, types); + + registerTableFunction( + name, + argSpecs, + std::make_shared(returnType), + Sequence::analyze, + TableFunction::defaultCreateDataProcessor, + [](const TableFunctionHandlePtr& handle, + velox::memory::MemoryPool* pool, + velox::HashStringAllocator* /*stringAllocator*/, + const velox::core::QueryConfig& /*queryConfig*/) + -> std::unique_ptr { + auto sequenceHandle = dynamic_cast(handle.get()); + return std::make_unique(pool, sequenceHandle); + }, + Sequence::getSplits); + SequenceHandle::registerSerDe(); + SequenceSplitHandle::registerSerDe(); +} + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/functions/TableFunctionsRegistration.cpp b/presto-native-execution/presto_cpp/main/tvf/functions/TableFunctionsRegistration.cpp new file mode 100644 index 0000000000000..3a3e0380c6037 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/functions/TableFunctionsRegistration.cpp @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/tvf/functions/TableFunctionsRegistration.h" +#include "presto_cpp/main/tvf/functions/TestingTableFunctions.h" + +namespace facebook::presto::tvf { + +extern void registerExcludeColumns(const std::string& name); +extern void registerSequence(const std::string& name); +extern void registerRemoteAnn(const std::string& name); + +void registerAllTableFunctions(const std::string& prefix) { + registerExcludeColumns(prefix + "exclude_columns"); + registerSequence(prefix + "sequence"); + registerRemoteAnn(prefix + "remoteAnn"); + registerRepeatFunction(prefix + "repeat_table_function"); + registerIdentityFunction(prefix + "identity_table_function"); + registerSimpleTableFunction(prefix + "simple_table_function"); +} + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/functions/TableFunctionsRegistration.h b/presto-native-execution/presto_cpp/main/tvf/functions/TableFunctionsRegistration.h new file mode 100644 index 0000000000000..0396ae4a64dfa --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/functions/TableFunctionsRegistration.h @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/tvf/spi/TableFunction.h" + +namespace facebook::presto::tvf { + +void registerAllTableFunctions(const std::string& prefix = ""); + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/functions/TestingTableFunctions.cpp b/presto-native-execution/presto_cpp/main/tvf/functions/TestingTableFunctions.cpp new file mode 100644 index 0000000000000..d6a12565823f0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/functions/TestingTableFunctions.cpp @@ -0,0 +1,166 @@ +/* +* Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/tvf/functions/TestingTableFunctions.h" + +#include "velox/vector/BaseVector.h" +#include "velox/vector/ConstantVector.h" + +using namespace facebook::velox; + +namespace facebook::presto::tvf { + +std::unique_ptr SimpleTableFunction::analyze( + const std::unordered_map>& args) { + std::vector returnNames; + std::vector returnTypes; + + const auto arg = std::dynamic_pointer_cast(args.at("COLUMN")); + const auto val = arg->value()->as>()->valueAt(0); + + returnNames.push_back(val); + returnTypes.push_back(BOOLEAN()); + + auto analysis = std::make_unique(); + analysis->tableFunctionHandle_ = std::make_shared(); + analysis->returnType_ = std::make_shared(returnNames, returnTypes); + return analysis; +} + +void registerSimpleTableFunction(const std::string& name) { + TableArgumentSpecList argSpecs; + argSpecs.insert( + std::make_shared("COLUMN", VARCHAR(), true)); + argSpecs.insert( + std::make_shared( + "IGNORED", BIGINT(), false)); + + registerTableFunction( + name, + argSpecs, + std::make_shared(), + SimpleTableFunction::analyze); +} + +std::shared_ptr IdentityDataProcessor::apply( + const std::vector& input) { + auto inputTable = input.at(0); + auto numRows = inputTable->size(); + if (numRows == 0) { + return std::make_shared(TableFunctionResult::TableFunctionState::kFinished); + } + return std::make_shared(true, std::move(inputTable)); +} + +std::unique_ptr IdentityFunction::analyze( + const std::unordered_map>& args) { + auto input = std::dynamic_pointer_cast(args.at("INPUT")); + std::vector returnNames = input->rowType()->names(); + std::vector returnTypes; + std::vector requiredColsList; + for (size_t i = 0; i < returnNames.size(); i++) { + returnTypes.push_back(input->rowType()->childAt(i)); + requiredColsList.push_back(i); + } + RequiredColumnsMap requiredColumns; + requiredColumns.emplace("INPUT", requiredColsList); + + auto analysis = std::make_unique(); + analysis->tableFunctionHandle_ = std::make_shared(); + analysis->returnType_ = std::make_shared(returnNames, returnTypes); + analysis->requiredColumns_ = requiredColumns; + return analysis; +} + +void registerIdentityFunction(const std::string& name) { + TableArgumentSpecList argSpecs; + argSpecs.insert(std::make_shared("INPUT", true, false, false)); + registerTableFunction( + name, + argSpecs, + std::make_shared(), + IdentityFunction::analyze, + [](const TableFunctionHandlePtr& handle, + memory::MemoryPool* pool, + HashStringAllocator* stringAllocator, + const velox::core::QueryConfig& config) + -> std::unique_ptr { + return std::make_unique( + dynamic_cast(handle.get()), pool); + }); +} + +std::shared_ptr RepeatFunctionDataProcessor::apply( + const std::vector& input) { + auto inputTable = input.at(0); + auto numRows = inputTable->size(); + if (numRows == 0) { + return std::make_shared(TableFunctionResult::TableFunctionState::kFinished); + } + + RowVectorPtr outputTable = RowVector::createEmpty(inputTable->rowType(), pool()); + auto count = handle_->count(); + outputTable->resize(numRows * count); + for (int i = 0; i < count; i++) { + outputTable->copy(inputTable.get(), i * numRows, 0, numRows); + } + + return std::make_shared(true, std::move(outputTable)); +} + +std::unique_ptr RepeatFunction::analyze( + const std::unordered_map>& args) { + auto input = std::dynamic_pointer_cast(args.at("INPUT")); + + auto countArg = args.at("COUNT"); + auto countPtr = std::dynamic_pointer_cast(countArg); + + auto count = countPtr->value()->as>()->valueAt(0); + + std::vector returnNames = input->rowType()->names(); + std::vector returnTypes; + std::vector requiredColsList; + for (size_t i = 0; i < returnNames.size(); i++) { + returnTypes.push_back(input->rowType()->childAt(i)); + requiredColsList.push_back(i); + } + RequiredColumnsMap requiredColumns; + requiredColumns.emplace("INPUT", requiredColsList); + auto analysis = std::make_unique(); + analysis->tableFunctionHandle_ = std::make_shared(count); + analysis->returnType_ = std::make_shared(returnNames, returnTypes); + analysis->requiredColumns_ = requiredColumns; + return analysis; +} + +void registerRepeatFunction(const std::string& name) { + TableArgumentSpecList argSpecs; + argSpecs.insert(std::make_shared("INPUT", true, false, false)); + argSpecs.insert(std::make_shared("COUNT", BIGINT(), true)); + registerTableFunction( + name, + argSpecs, + std::make_shared(), + RepeatFunction::analyze, + [](const TableFunctionHandlePtr& handle, + memory::MemoryPool* pool, + HashStringAllocator* stringAllocator, + const velox::core::QueryConfig& config) + -> std::unique_ptr { + return std::make_unique( + dynamic_cast(handle.get()), pool); + }); +} + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/functions/TestingTableFunctions.h b/presto-native-execution/presto_cpp/main/tvf/functions/TestingTableFunctions.h new file mode 100644 index 0000000000000..367b629f84e78 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/functions/TestingTableFunctions.h @@ -0,0 +1,155 @@ +/* +* Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/main/tvf/spi/TableFunction.h" + +using namespace facebook::velox; + +namespace facebook::presto::tvf { + +class SimpleTableFunctionHandle : public TableFunctionHandle { + public: + std::string_view name() const override { + return "SimpleTableFunctionHandle"; + }; + + folly::dynamic serialize() const override { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = fmt::format("{}", name()); + return obj; + }; + + static std::shared_ptr create( + const folly::dynamic& obj, void* context) { + return std::shared_ptr(); + }; + + static void registerSerDe() { + auto& registry = velox::DeserializationWithContextRegistryForSharedPtr(); + registry.Register("SimpleTableFunctionHandle", create); + } +}; + +class SimpleTableFunctionAnalysis : public TableFunctionAnalysis {}; + +class SimpleTableFunction final : public TableFunction { + public: + static std::unique_ptr analyze( + const std::unordered_map>& args); +}; + +void registerSimpleTableFunction(const std::string& name); + +class IdentityFunctionHandle : public TableFunctionHandle { + std::string_view name() const override { + return "IdentityFunctionHandle"; + }; + + folly::dynamic serialize() const override { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = fmt::format("{}", name()); + return obj; + }; + + static std::shared_ptr create( + const folly::dynamic& obj, void* context) { + return std::shared_ptr(); + }; + + static void registerSerDe() { + auto& registry = velox::DeserializationWithContextRegistryForSharedPtr(); + registry.Register("IdentityFunctionHandle", create); + } +}; + +class IdentityFunctionAnalysis : public TableFunctionAnalysis {}; + +class IdentityDataProcessor : public TableFunctionDataProcessor { + public: + explicit IdentityDataProcessor(const IdentityFunctionHandle* handle, memory::MemoryPool* pool) + : TableFunctionDataProcessor("identity", pool, nullptr), handle_(handle) {} + + std::shared_ptr apply( + const std::vector& input) override; + + private: + const IdentityFunctionHandle* handle_; +}; + +class IdentityFunction final : public TableFunction { + public: + static std::unique_ptr analyze( + const std::unordered_map>& args); +}; + +void registerIdentityFunction(const std::string& name); + +class RepeatFunctionHandle : public TableFunctionHandle { + public: + explicit RepeatFunctionHandle(int64_t count) : count_(count) {} + + std::string_view name() const override { + return "RepeatFunctionHandle"; + }; + + folly::dynamic serialize() const override { + folly::dynamic obj = folly::dynamic::object; + obj["name"] = fmt::format("{}", name()); + obj["count"] = count_; + return obj; + }; + + static std::shared_ptr create( + const folly::dynamic& obj, void* context) { + return std::make_shared(obj["count"].asInt()); + }; + + static void registerSerDe() { + auto& registry = velox::DeserializationWithContextRegistryForSharedPtr(); + registry.Register("RepeatFunctionHandle", create); + } + + int64_t count() const { + return count_; + } + + private: + int64_t count_; +}; + +class RepeatFunctionDataProcessor : public TableFunctionDataProcessor { + public: + RepeatFunctionDataProcessor( + const RepeatFunctionHandle* handle, + velox::memory::MemoryPool* pool) + : TableFunctionDataProcessor("repeat", pool, nullptr), handle_(handle) {} + + std::shared_ptr apply( + const std::vector& input) override; + + private: + const RepeatFunctionHandle* handle_; +}; + +class RepeatFunctionAnalysis : public TableFunctionAnalysis {}; + +class RepeatFunction final : public TableFunction { + public: + static std::unique_ptr analyze( + const std::unordered_map>& args); +}; + +void registerRepeatFunction(const std::string& name); +} diff --git a/presto-native-execution/presto_cpp/main/tvf/spi/Argument.h b/presto-native-execution/presto_cpp/main/tvf/spi/Argument.h new file mode 100644 index 0000000000000..12ee2d2c60dd2 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/spi/Argument.h @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/core/Expressions.h" + +namespace facebook::presto::tvf { + +class Argument { + public: + Argument() {} + + virtual ~Argument() = default; + + protected: +}; + +class ArgumentSpecification { + public: + ArgumentSpecification(std::string name, bool required) + : name_(std::move(name)), required_(required){}; + + virtual ~ArgumentSpecification() = default; + + const std::string& name() const { + return name_; + } + + const bool required() const { + return required_; + } + + private: + const std::string name_; + const bool required_; + // TODO : Add default value. +}; + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/spi/CMakeLists.txt b/presto-native-execution/presto_cpp/main/tvf/spi/CMakeLists.txt new file mode 100644 index 0000000000000..ccffec24380ab --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/spi/CMakeLists.txt @@ -0,0 +1,15 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_library(presto_tvf_spi TableFunction.cpp) + +target_link_libraries(presto_tvf_spi velox_core velox_vector) diff --git a/presto-native-execution/presto_cpp/main/tvf/spi/Descriptor.h b/presto-native-execution/presto_cpp/main/tvf/spi/Descriptor.h new file mode 100644 index 0000000000000..726b9d35bbf6a --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/spi/Descriptor.h @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/tvf/spi/Argument.h" + +#include "velox/core/Expressions.h" + +namespace facebook::presto::tvf { + +class Descriptor : public Argument { + public: + Descriptor(std::vector names) + : type_(DescriptorType::kOnlyName), names_(std::move(names)) {} + + Descriptor(std::vector names, std::vector types) + : type_(DescriptorType::kNameType), + names_(std::move(names)), + types_(std::move(types)) { + VELOX_CHECK_EQ(names_.size(), types_.size()); + fields_.reserve(names_.size()); + for (velox::column_index_t i = 0; i < names_.size(); i++) { + fields_.push_back(std::make_shared( + types_.at(i), names_.at(i))); + } + } + + const std::vector fields() const { + VELOX_CHECK(type_ == DescriptorType::kNameType); + return fields_; + } + + const std::vector names() const { + return names_; + } + + const std::vector types() const { + return types_; + } + + private: + enum class DescriptorType { + kOnlyName, + kNameType, + }; + DescriptorType type_; + + std::vector names_; + std::vector types_; + + std::vector fields_; +}; + +class DescriptorArgumentSpecification : public ArgumentSpecification { + public: + DescriptorArgumentSpecification( + std::string name, + Descriptor descriptor_, + bool required) + : ArgumentSpecification(name, required), + descriptor_(std::move(descriptor_)){}; + + const Descriptor descriptor() const { + return descriptor_; + } + + private: + const Descriptor descriptor_; +}; + +using DescriptorPtr = std::shared_ptr; +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/spi/ReturnTypeSpecification.h b/presto-native-execution/presto_cpp/main/tvf/spi/ReturnTypeSpecification.h new file mode 100644 index 0000000000000..ff4f276fd78cd --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/spi/ReturnTypeSpecification.h @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/tvf/spi/Descriptor.h" + +namespace facebook::presto::tvf { + +class ReturnTypeSpecification { + public: + enum class ReturnType { kGenericTable, kDescribedTable }; + + ReturnTypeSpecification(ReturnType returnType) : returnType_(returnType){}; + + ReturnType returnType() const { + return returnType_; + } + + virtual ~ReturnTypeSpecification() = default; + + private: + ReturnType returnType_; +}; +using ReturnSpecPtr = std::shared_ptr; + +class GenericTableReturnType : public ReturnTypeSpecification { + public: + GenericTableReturnType() + : ReturnTypeSpecification( + ReturnTypeSpecification::ReturnType::kGenericTable){}; +}; + +class DescribedTableReturnType : public ReturnTypeSpecification { + public: + DescribedTableReturnType(DescriptorPtr descriptor) + : ReturnTypeSpecification( + ReturnTypeSpecification::ReturnType::kDescribedTable), + descriptor_(std::move(descriptor)) {} + + const DescriptorPtr descriptor() const { + return descriptor_; + } + + private: + DescriptorPtr descriptor_; +}; + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/spi/ScalarArgument.h b/presto-native-execution/presto_cpp/main/tvf/spi/ScalarArgument.h new file mode 100644 index 0000000000000..6fed6ce2d6cd1 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/spi/ScalarArgument.h @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/tvf/spi/Argument.h" + +#include "velox/core/Expressions.h" + +namespace facebook::presto::tvf { + +class ScalarArgument : public Argument { + public: + ScalarArgument(velox::TypePtr type, velox::VectorPtr value) + : type_(std::move(type)), constantValue_(std::move(value)) {} + + const velox::TypePtr rowType() const { + return type_; + } + + const velox::VectorPtr value() const { + return constantValue_; + } + + private: + const velox::TypePtr type_; + const velox::VectorPtr constantValue_; +}; + +class ScalarArgumentSpecification : public ArgumentSpecification { + public: + ScalarArgumentSpecification( + std::string name, + velox::TypePtr type, + bool required) + : ArgumentSpecification(name, required), type_(std::move(type)){}; + + const velox::TypePtr rowType() const { + return type_; + } + + private: + const velox::TypePtr type_; +}; + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/spi/TableArgument.h b/presto-native-execution/presto_cpp/main/tvf/spi/TableArgument.h new file mode 100644 index 0000000000000..f74ab21006c74 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/spi/TableArgument.h @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/tvf/spi/Argument.h" + +#include "velox/core/Expressions.h" +#include "velox/core/PlanNode.h" + +namespace facebook::presto::tvf { + +class TableArgument : public Argument { + public: + TableArgument(velox::RowTypePtr type, + std::vector partitionKeys = {}, + std::vector sortingKeys = {}, + std::vector sortingOrders = {}) + : rowType_(std::move(type)), + partitionKeys_(std::move(partitionKeys)), + sortingKeys_(std::move(sortingKeys)), + sortingOrders_(std::move(sortingOrders)) { + VELOX_CHECK_EQ( + sortingKeys_.size(), + sortingOrders_.size(), + "Number of sorting keys must be equal to the number of sorting orders"); + + std::unordered_set keyNames; + for (const auto& key : partitionKeys_) { + VELOX_USER_CHECK( + keyNames.insert(key->name()).second, + "Partitioning keys must be unique. Found duplicate key: {}", + key->name()); + } + + for (const auto& key : sortingKeys_) { + VELOX_USER_CHECK( + keyNames.insert(key->name()).second, + "Sorting keys must be unique and not overlap with partitioning keys. Found duplicate key: {}", + key->name()); + } + } + + velox::RowTypePtr rowType() const { + return rowType_; + } + + const std::vector& partitionKeys() const { + return partitionKeys_; + } + + const std::vector& sortingKeys() const { + return sortingKeys_; + } + + const std::vector& sortingOrders() const { + return sortingOrders_; + } + + private: + const velox::RowTypePtr rowType_; + const std::vector partitionKeys_; + const std::vector sortingKeys_; + const std::vector sortingOrders_; +}; + +class TableArgumentSpecification : public ArgumentSpecification { + public: + TableArgumentSpecification(std::string name, bool rowSemantics, bool pruneWhenEmpty, bool passThroughColumns) + : ArgumentSpecification(name, true), + rowSemantics_(rowSemantics), + pruneWhenEmpty_(pruneWhenEmpty), + passThroughColumns_(passThroughColumns) {}; + + bool rowSemantics() const { + return rowSemantics_; + } + + bool pruneWhenEmpty() const { + return pruneWhenEmpty_; + } + + bool passThroughColumns() const { + return passThroughColumns_; + } + + private: + const bool rowSemantics_; + const bool pruneWhenEmpty_; + const bool passThroughColumns_; +}; + +using TableArgumentSpecList = + std::unordered_set>; + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/spi/TableFunction.cpp b/presto-native-execution/presto_cpp/main/tvf/spi/TableFunction.cpp new file mode 100644 index 0000000000000..678d7e01da708 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/spi/TableFunction.cpp @@ -0,0 +1,132 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/tvf/spi/TableFunction.h" + +#include "velox/expression/FunctionSignature.h" + +namespace facebook::presto::tvf { + +using namespace facebook::velox; + +TableFunctionMap& tableFunctions() { + static TableFunctionMap functions; + return functions; +} + +namespace { +std::optional getTableFunctionEntry( + const std::string& name) { + auto& functionsMap = tableFunctions(); + auto it = functionsMap.find(name); + if (it != functionsMap.end()) { + return &it->second; + } + + return std::nullopt; +} +} // namespace + +bool registerTableFunction( + const std::string& name, + TableArgumentSpecList argumentsSpec, + ReturnSpecPtr returnSpec, + TableFunctionAnalyzer analyzer, + TableFunctionDataProcessorFactory dataProcessorfactory, + TableFunctionSplitProcessorFactory splitProcessorfactory, + TableFunctionSplitGenerator splitGenerator) { + auto sanitizedName = exec::sanitizeName(name); + tableFunctions().insert( + {sanitizedName, + {std::move(argumentsSpec), + std::move(returnSpec), + std::move(analyzer), + std::move(dataProcessorfactory), + std::move(splitProcessorfactory), + std::move(splitGenerator)}}); + return true; +} + +ReturnSpecPtr getTableFunctionReturnType(const std::string& name) { + const auto sanitizedName = exec::sanitizeName(name); + if (auto func = getTableFunctionEntry(sanitizedName)) { + return func.value()->returnSpec; + } else { + VELOX_USER_FAIL("ReturnTypeSpecification not found for function: {}", name); + } +} + +TableArgumentSpecList getTableFunctionArgumentSpecs(const std::string& name) { + const auto sanitizedName = exec::sanitizeName(name); + if (auto func = getTableFunctionEntry(sanitizedName)) { + return func.value()->argumentsSpec; + } else { + VELOX_USER_FAIL("Arguments Specification not found for function: {}", name); + } +} + +std::unique_ptr TableFunction::analyze( + const std::string& name, + const std::unordered_map>& args) { + // Lookup the function in the new registry first. + if (auto func = getTableFunctionEntry(name)) { + return func.value()->analyzer(args); + } + + VELOX_USER_FAIL("Table function not registered: {}", name); +} + +std::unique_ptr TableFunction::createDataProcessor( + const std::string& name, + const std::shared_ptr& handle, + memory::MemoryPool* pool, + HashStringAllocator* stringAllocator, + const core::QueryConfig& config) { + // Lookup the function in the new registry first. + if (auto func = getTableFunctionEntry(name)) { + return func.value()->dataProcessorFactory( + handle, pool, stringAllocator, config); + } + + VELOX_USER_FAIL("Table function not registered: {}", name); +} + +std::unique_ptr +TableFunction::createSplitProcessor( + const std::string& name, + const std::shared_ptr& handle, + memory::MemoryPool* pool, + HashStringAllocator* stringAllocator, + const core::QueryConfig& config) { + // Lookup the function in the new registry first. + if (auto func = getTableFunctionEntry(name)) { + return func.value()->splitProcessorFactory( + handle, pool, stringAllocator, config); + } + + VELOX_USER_FAIL("Table function not registered: {}", name); +} + +std::vector TableFunction::getSplits( + const std::string& name, + const TableFunctionHandlePtr& handle) { + // Lookup the function in the new registry first. + if (auto func = getTableFunctionEntry(name)) { + return func.value()->splitGenerator(handle); + } + + VELOX_USER_FAIL("Table function not registered: {}", name); +} + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/spi/TableFunction.h b/presto-native-execution/presto_cpp/main/tvf/spi/TableFunction.h new file mode 100644 index 0000000000000..69f6c953f4558 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/spi/TableFunction.h @@ -0,0 +1,204 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/tvf/spi/Descriptor.h" +#include "presto_cpp/main/tvf/spi/ReturnTypeSpecification.h" +#include "presto_cpp/main/tvf/spi/ScalarArgument.h" +#include "presto_cpp/main/tvf/spi/TableArgument.h" +#include "presto_cpp/main/tvf/spi/TableFunctionAnalysis.h" +#include "presto_cpp/main/tvf/spi/TableFunctionResult.h" + +#include "velox/common/memory/HashStringAllocator.h" +#include "velox/core/QueryConfig.h" +#include "velox/vector/BaseVector.h" +#include "velox/vector/ComplexVector.h" + +namespace facebook::presto::tvf { + +class TableFunctionDataProcessor { + public: + explicit TableFunctionDataProcessor( + const std::string& name, + velox::memory::MemoryPool* pool, + velox::HashStringAllocator* stringAllocator) + : name_(name), pool_(pool), stringAllocator_(stringAllocator) {} + + virtual ~TableFunctionDataProcessor() = default; + + velox::memory::MemoryPool* pool() const { + return pool_; + } + + const velox::HashStringAllocator* stringAllocator() const { + return stringAllocator_; + } + + const std::string name() const { + return name_; + } + + virtual std::shared_ptr apply( + const std::vector& input) { + VELOX_NYI(" TableFunction::apply() for input vector is not implemented"); + } + + protected: + const std::string name_; + velox::memory::MemoryPool* pool_; + velox::HashStringAllocator* const stringAllocator_; +}; + +class TableFunctionSplitProcessor { + public: + explicit TableFunctionSplitProcessor( + const std::string& name, + velox::memory::MemoryPool* pool, + velox::HashStringAllocator* stringAllocator) + : pool_(pool), stringAllocator_(stringAllocator) {} + + virtual ~TableFunctionSplitProcessor() = default; + + velox::memory::MemoryPool* pool() const { + return pool_; + } + + const velox::HashStringAllocator* stringAllocator() const { + return stringAllocator_; + } + + const std::string name() const { + return name_; + } + + virtual std::shared_ptr apply( + const TableSplitHandlePtr& split) { + VELOX_NYI(" TableFunction::apply() for split is not implemented"); + } + + protected: + const std::string name_; + velox::memory::MemoryPool* pool_; + velox::HashStringAllocator* const stringAllocator_; +}; + +class TableFunction { + public: + explicit TableFunction() {}; + + virtual ~TableFunction() = default; + + static std::unique_ptr analyze( + const std::string& name, + const std::unordered_map>& args); + + static std::unique_ptr createDataProcessor( + const std::string& name, + const TableFunctionHandlePtr& handle, + velox::memory::MemoryPool* pool, + velox::HashStringAllocator* stringAllocator, + const velox::core::QueryConfig& config); + + static std::unique_ptr + defaultCreateDataProcessor( + const TableFunctionHandlePtr& /* handle */, + velox::memory::MemoryPool* /* pool */, + velox::HashStringAllocator* /* stringAllocator */, + const velox::core::QueryConfig& /* config */) { + VELOX_NYI("TableFunction::createDataProcessor is not implemented"); + } + + static std::unique_ptr createSplitProcessor( + const std::string& name, + const TableFunctionHandlePtr& handle, + velox::memory::MemoryPool* pool, + velox::HashStringAllocator* stringAllocator, + const velox::core::QueryConfig& config); + + static std::unique_ptr + defaultCreateSplitProcessor( + const TableFunctionHandlePtr& /* handle */, + velox::memory::MemoryPool* /* pool */, + velox::HashStringAllocator* /* stringAllocator */, + const velox::core::QueryConfig& /* config */) { + VELOX_NYI("TableFunction::createSplitProcessor is not implemented"); + } + + static std::vector getSplits( + const std::string& name, + const TableFunctionHandlePtr& handle); + + static std::vector defaultGetSplits( + const TableFunctionHandlePtr& /* handle */) { + VELOX_NYI("TableFunction::getSplits is not implemented"); + } +}; + +using TableFunctionAnalyzer = + std::function( + const std::unordered_map>& + args)>; + +using TableFunctionDataProcessorFactory = + std::function( + const TableFunctionHandlePtr& handle, + velox::memory::MemoryPool* pool, + velox::HashStringAllocator* stringAllocator, + const velox::core::QueryConfig& config)>; + +using TableFunctionSplitProcessorFactory = + std::function( + const TableFunctionHandlePtr& handle, + velox::memory::MemoryPool* pool, + velox::HashStringAllocator* stringAllocator, + const velox::core::QueryConfig& config)>; + +using TableFunctionSplitGenerator = + std::function( + const TableFunctionHandlePtr& handle)>; + +struct TableFunctionEntry { + TableArgumentSpecList argumentsSpec; + ReturnSpecPtr returnSpec; + TableFunctionAnalyzer analyzer; + TableFunctionDataProcessorFactory dataProcessorFactory; + TableFunctionSplitProcessorFactory splitProcessorFactory; + TableFunctionSplitGenerator splitGenerator; +}; + +/// Register a Table function with the specified name. +/// Registering a function with the same name a second time overrides the +/// first registration. +bool registerTableFunction( + const std::string& name, + TableArgumentSpecList argumentsSpec, + ReturnSpecPtr returnSpec, + TableFunctionAnalyzer analyzer, + TableFunctionDataProcessorFactory dataProcessorFactory = + TableFunction::defaultCreateDataProcessor, + TableFunctionSplitProcessorFactory splitProcessorFactory = + TableFunction::defaultCreateSplitProcessor, + TableFunctionSplitGenerator splitGenerator = + TableFunction::defaultGetSplits); + +ReturnSpecPtr getTableFunctionReturnType(const std::string& name); + +TableArgumentSpecList getTableFunctionArgumentSpecs(const std::string& name); + +using TableFunctionMap = std::unordered_map; + +/// Returns a map of all Table function names to their registrations. +TableFunctionMap& tableFunctions(); +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/spi/TableFunctionAnalysis.h b/presto-native-execution/presto_cpp/main/tvf/spi/TableFunctionAnalysis.h new file mode 100644 index 0000000000000..dcb1df786f7d0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/spi/TableFunctionAnalysis.h @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/tvf/spi/Descriptor.h" + +#include "velox/common/serialization/DeserializationRegistry.h" +#include "velox/core/Expressions.h" + +namespace facebook::presto::tvf { + +class TableFunctionHandle : public velox::ISerializable { + public: + // This name is used for looking up the deserializer for the + // TableFunctionHandle in the registry. + virtual std::string_view name() const = 0; + + virtual folly::dynamic serialize() const = 0; +}; + +using TableFunctionHandlePtr = std::shared_ptr; + +class TableSplitHandle : public velox::ISerializable { + public: + // This name is used for looking up the deserializer for the + // TableSplitHandle in the registry. + virtual std::string_view name() const = 0; + + virtual folly::dynamic serialize() const = 0; +}; + +using TableSplitHandlePtr = std::shared_ptr; + +using RequiredColumnsMap = + std::unordered_map>; + +// TODO : Make this implement ISerializable as well ? +class TableFunctionAnalysis { + public: + TableFunctionAnalysis() {} + + const TableFunctionHandlePtr tableFunctionHandle() const { + return tableFunctionHandle_; + } + + const DescriptorPtr returnType() const { + return returnType_; + } + + const RequiredColumnsMap requiredColumns() const { + return requiredColumns_; + } + + // Add a builder so that these can be set outside. + // protected: + DescriptorPtr returnType_; + TableFunctionHandlePtr tableFunctionHandle_; + RequiredColumnsMap requiredColumns_; +}; + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/spi/TableFunctionResult.h b/presto-native-execution/presto_cpp/main/tvf/spi/TableFunctionResult.h new file mode 100644 index 0000000000000..fd291cef1aa3a --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/spi/TableFunctionResult.h @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// #include "velox/core/Expressions.h" +#include "velox/vector/ComplexVector.h" + +namespace facebook::presto::tvf { + +class TableFunctionResult { + public: + enum class TableFunctionState { + kBlocked, + kFinished, + kProcessed, + }; + + TableFunctionResult(TableFunctionState state) : state_(state) { + VELOX_CHECK(state == TableFunctionState::kFinished); + } + + TableFunctionResult(bool usedInput, velox::RowVectorPtr result) + : state_(TableFunctionState::kProcessed), + usedInput_(usedInput), + result_(std::move(result)) {} + + TableFunctionResult::TableFunctionState state() const { + return state_; + } + + bool usedInput() const { + return usedInput_; + } + + [[nodiscard]] velox::RowVectorPtr result() const { + return result_; + } + + private: + TableFunctionState state_; + + bool usedInput_; + velox::RowVectorPtr result_; +}; + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/tvf/tests/CMakeLists.txt new file mode 100644 index 0000000000000..92f5171a36b1b --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/tests/CMakeLists.txt @@ -0,0 +1,46 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_library(presto_tvf_plan_builder PlanBuilder.cpp) + +target_link_libraries( + presto_tvf_plan_builder + presto_tvf_core + presto_tvf_functions + presto_tvf_spi + velox_core +) + +add_executable(presto_tvf_test ExcludeColumnsTest.cpp PlanNodeSerdeTest.cpp SequenceTest.cpp TableFunctionsInvocationTest.cpp) + +add_test(presto_tvf_test presto_tvf_test) + +target_link_libraries( + presto_tvf_test + presto_tvf_core + presto_tvf_exec + presto_tvf_functions + presto_tvf_plan_builder + presto_tvf_spi + velox_vector_fuzzer + velox_exec_test_lib + velox_vector_test_lib + velox_type + velox_vector + velox_exec + velox_memory + velox_exec + GTest::gmock + GTest::gtest + GTest::gtest_main +) + +set_property(TARGET presto_tvf_test PROPERTY JOB_POOL_LINK presto_link_job_pool) diff --git a/presto-native-execution/presto_cpp/main/tvf/tests/ExcludeColumnsTest.cpp b/presto-native-execution/presto_cpp/main/tvf/tests/ExcludeColumnsTest.cpp new file mode 100644 index 0000000000000..5d00c33ae4491 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/tests/ExcludeColumnsTest.cpp @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "presto_cpp/main/tvf/exec/TableFunctionOperator.h" +#include "presto_cpp/main/tvf/exec/TableFunctionTranslator.h" +#include "presto_cpp/main/tvf/functions/TableFunctionsRegistration.h" +#include "presto_cpp/main/tvf/tests/PlanBuilder.h" + +#include "velox/core/PlanNode.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" + +using namespace facebook::presto::tvf; + +namespace facebook::velox::exec::test { +class ExcludeColumnsTest : public OperatorTestBase { + protected: + void SetUp() override { + exec::test::OperatorTestBase::SetUp(); + } + + ExcludeColumnsTest() { + functions::prestosql::registerAllScalarFunctions(); + aggregate::prestosql::registerAllAggregateFunctions(); + registerAllTableFunctions(""); + parse::registerTypeResolver(); + + Type::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + + // This code is added in PrestoToVeloxQueryPlan. + auto& registry = DeserializationWithContextRegistryForSharedPtr(); + registry.Register( + "TableFunctionProcessorNode", + presto::tvf::TableFunctionProcessorNode::create); + + velox::exec::Operator::registerOperator( + std::make_unique()); + } +}; + +TEST_F(ExcludeColumnsTest, basic) { + auto data = makeRowVector({ + makeFlatVector({1, 2, 3}), + makeFlatVector({10, 20, 30}), + makeConstant(true, 3), + }); + auto type = asRowType(data->type()); + + std::unordered_map> args; + std::vector excludeColumnNames = {"c0"}; + auto excludeColumnsDesc = std::make_shared(excludeColumnNames); + args.insert({"COLUMNS", excludeColumnsDesc}); + auto inputDesc = std::make_shared(type); + args.insert({"INPUT", inputDesc}); + + auto plan = exec::test::PlanBuilder() + .values({data}) + .addNode(addTvfNode("exclude_columns", args)) + .planNode(); + auto expected = makeRowVector({ + makeFlatVector({10, 20, 30}), + makeConstant(true, 3), + }); + assertQuery(plan, expected); +} +} // namespace facebook::velox::exec::test diff --git a/presto-native-execution/presto_cpp/main/tvf/tests/PlanBuilder.cpp b/presto-native-execution/presto_cpp/main/tvf/tests/PlanBuilder.cpp new file mode 100644 index 0000000000000..b40847f81a8be --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/tests/PlanBuilder.cpp @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/tvf/tests/PlanBuilder.h" + +using namespace facebook::velox; +using namespace facebook::velox::core; + +namespace facebook::presto::tvf { + +std::function< + velox::core::PlanNodePtr(std::string nodeId, velox::core::PlanNodePtr)> +addTvfNode( + const std::string& name, + const std::unordered_map>& args, + const std::vector& partitionKeys, + const std::vector& sortingKeys, + const std::vector& sortingOrders) { + return [&name, &args, &partitionKeys, &sortingKeys, &sortingOrders](PlanNodeId nodeId, PlanNodePtr source) -> PlanNodePtr { + // Validate the user has provided all required arguments. + auto argsList = getTableFunctionArgumentSpecs(name); + for (const auto arg : argsList) { + if (arg->required()) { + VELOX_CHECK_GT(args.count(arg->name()), 0); + } + } + + auto analysis = TableFunction::analyze(name, args); + VELOX_CHECK(analysis); + VELOX_CHECK(analysis->tableFunctionHandle()); + + RowTypePtr outputType; + auto returnTypeSpec = getTableFunctionReturnType(name); + if (returnTypeSpec->returnType() == + ReturnTypeSpecification::ReturnType::kGenericTable) { + VELOX_CHECK(analysis->returnType()); + + auto names = analysis->returnType()->names(); + auto types = analysis->returnType()->types(); + outputType = velox::ROW(std::move(names), std::move(types)); + } else { + auto describedTableSpec = + std::dynamic_pointer_cast(returnTypeSpec); + auto names = describedTableSpec->descriptor()->names(); + auto types = describedTableSpec->descriptor()->types(); + outputType = velox::ROW(std::move(names), std::move(types)); + } + + std::vector sources; + if (source == nullptr) { + sources.clear(); + } else { + sources.push_back(source); + } + + std::vector requiredColumns = {}; + if (analysis->requiredColumns().count("INPUT")) { + requiredColumns = analysis->requiredColumns().at("INPUT"); + } + + return std::make_shared( + nodeId, + name, + analysis->tableFunctionHandle(), + partitionKeys, + sortingKeys, + sortingOrders, + outputType, + requiredColumns, + sources); + }; +} +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/tests/PlanBuilder.h b/presto-native-execution/presto_cpp/main/tvf/tests/PlanBuilder.h new file mode 100644 index 0000000000000..96162648940bb --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/tests/PlanBuilder.h @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/tvf/core/TableFunctionProcessorNode.h" +#include "presto_cpp/main/tvf/spi/TableFunction.h" +#include "velox/core/PlanNode.h" + +namespace facebook::presto::tvf { + +// Helper functions to use with PlanBuilder::addNode. + +std::function< + velox::core::PlanNodePtr(std::string nodeId, velox::core::PlanNodePtr)> +addTvfNode( + const std::string& name, + const std::unordered_map>& args = {}, + const std::vector& partitionKeys = {}, + const std::vector& sortingKeys = {}, + const std::vector& sortingOrders = {}); + +} // namespace facebook::presto::tvf diff --git a/presto-native-execution/presto_cpp/main/tvf/tests/PlanNodeSerdeTest.cpp b/presto-native-execution/presto_cpp/main/tvf/tests/PlanNodeSerdeTest.cpp new file mode 100644 index 0000000000000..3078f9e48298b --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/tests/PlanNodeSerdeTest.cpp @@ -0,0 +1,115 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "presto_cpp/main/tvf/functions/TableFunctionsRegistration.h" +#include "presto_cpp/main/tvf/tests/PlanBuilder.h" + +#include "velox/core/PlanNode.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::presto::tvf; + +namespace facebook::velox::exec::test { +class PlanNodeSerdeTest : public testing::Test, + public velox::test::VectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance(memory::MemoryManager::Options{}); + } + + PlanNodeSerdeTest() { + functions::prestosql::registerAllScalarFunctions(); + aggregate::prestosql::registerAllAggregateFunctions(); + registerAllTableFunctions(""); + parse::registerTypeResolver(); + + Type::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + + // This code is added in PrestoToVeloxQueryPlan. + auto& registry = DeserializationWithContextRegistryForSharedPtr(); + registry.Register( + "TableFunctionProcessorNode", + presto::tvf::TableFunctionProcessorNode::create); + + data_ = {makeRowVector({ + makeFlatVector({1, 2, 3}), + makeFlatVector({10, 20, 30}), + makeConstant(true, 3), + })}; + type_ = asRowType(data_[0]->type()); + } + + void testSerde(const core::PlanNodePtr& plan) { + auto serialized = plan->serialize(); + + auto copy = + velox::ISerializable::deserialize(serialized, pool()); + + LOG(INFO) << "\nplan->toString" << plan->toString(true, true) << "\n"; + LOG(INFO) << "\ncopy->toString" << copy->toString(true, true) << "\n"; + ASSERT_EQ(plan->toString(true, true), copy->toString(true, true)); + } + + /*static std::vector reverseColumns(const RowTypePtr& rowType) { + auto names = rowType->names(); + std::reverse(names.begin(), names.end()); + return names; + }*/ + + std::vector data_; + RowTypePtr type_; +}; + +TEST_F(PlanNodeSerdeTest, excludeColumns) { + std::unordered_map> args; + std::vector excludeColumnNames = {"c0"}; + auto excludeColumnsDesc = std::make_shared(excludeColumnNames); + args.insert({"COLUMNS", excludeColumnsDesc}); + auto inputDesc = std::make_shared(type_); + args.insert({"INPUT", inputDesc}); + auto plan = exec::test::PlanBuilder() + .values(data_, true) + .addNode(addTvfNode("exclude_columns", args)) + .planNode(); + testSerde(plan); +} + +TEST_F(PlanNodeSerdeTest, sequence) { + std::unordered_map> args; + args.insert( + {"START", + std::make_shared( + BIGINT(), makeConstant(static_cast(1), 1, BIGINT()))}); + args.insert( + {"STOP", + std::make_shared( + BIGINT(), makeConstant(static_cast(10), 1, BIGINT()))}); + args.insert( + {"STEP", + std::make_shared( + BIGINT(), makeConstant(static_cast(1), 1, BIGINT()))}); + auto plan = exec::test::PlanBuilder() + .addNode(addTvfNode("sequence", args)) + .planNode(); + testSerde(plan); +} +} // namespace facebook::velox::exec::test diff --git a/presto-native-execution/presto_cpp/main/tvf/tests/SequenceTest.cpp b/presto-native-execution/presto_cpp/main/tvf/tests/SequenceTest.cpp new file mode 100644 index 0000000000000..84427d66db995 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/tests/SequenceTest.cpp @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "presto_cpp/main/tvf/core/TableFunctionProcessorNode.h" +#include "presto_cpp/main/tvf/exec/TableFunctionOperator.h" +#include "presto_cpp/main/tvf/exec/TableFunctionSplit.h" +#include "presto_cpp/main/tvf/exec/TableFunctionTranslator.h" +#include "presto_cpp/main/tvf/functions/TableFunctionsRegistration.h" +#include "presto_cpp/main/tvf/tests/PlanBuilder.h" + +#include "velox/core/PlanNode.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/parse/TypeResolver.h" + +using namespace facebook::presto::tvf; + +namespace facebook::velox::exec::test { + +class SequenceTest : public OperatorTestBase { + protected: + void SetUp() override { + exec::test::OperatorTestBase::SetUp(); + } + + SequenceTest() { + functions::prestosql::registerAllScalarFunctions(); + aggregate::prestosql::registerAllAggregateFunctions(); + registerAllTableFunctions(""); + parse::registerTypeResolver(); + + Type::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + + // This code is added in PrestoToVeloxQueryPlan. + auto& registry = DeserializationWithContextRegistryForSharedPtr(); + registry.Register( + "TableFunctionProcessorNode", + presto::tvf::TableFunctionProcessorNode::create); + + velox::exec::Operator::registerOperator( + std::make_unique()); + } + + std::unordered_map> + sequenceArgs(int64_t start, int64_t stop, int64_t step) { + std::unordered_map> args; + args.insert( + {"START", + std::make_shared( + BIGINT(), makeConstant(start, 1, BIGINT()))}); + args.insert( + {"STOP", + std::make_shared( + BIGINT(), makeConstant(stop, 1, BIGINT()))}); + args.insert( + {"STEP", + std::make_shared( + BIGINT(), makeConstant(step, 1, BIGINT()))}); + + return args; + } + + std::vector splitsForTvf(const core::PlanNodePtr& node) { + auto sequenceTvfNode = + dynamic_pointer_cast(node); + auto sequenceSplits = + TableFunction::getSplits("sequence", sequenceTvfNode->handle()); + std::vector tvfSplits; + for (auto sequenceSplit : sequenceSplits) { + auto tableFunctionSplit = + std::make_shared(sequenceSplit); + tvfSplits.push_back(velox::exec::Split(tableFunctionSplit)); + } + + return tvfSplits; + } +}; + +TEST_F(SequenceTest, basic) { + auto plan = exec::test::PlanBuilder() + .addNode(addTvfNode("sequence", sequenceArgs(10, 30, 2))) + .planNode(); + + auto expected = makeRowVector( + {makeFlatVector({10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30})}); + + auto sequenceTvfNode = + dynamic_pointer_cast(plan); + auto sequenceSplits = + TableFunction::getSplits("sequence", sequenceTvfNode->handle()); + std::vector tvfSplits; + for (auto sequenceSplit : sequenceSplits) { + auto tableFunctionSplit = + std::make_shared(sequenceSplit); + tvfSplits.push_back(velox::exec::Split(tableFunctionSplit)); + } + + AssertQueryBuilder(plan).splits(splitsForTvf(plan)).assertResults(expected); +} + +TEST_F(SequenceTest, join) { + auto planNodeIdGenerator = std::make_shared(); + + core::PlanNodeId sourceId1; + auto source1 = exec::test::PlanBuilder(planNodeIdGenerator) + .addNode(addTvfNode("sequence", sequenceArgs(10, 30, 2))) + .capturePlanNodeId(sourceId1) + .planNode(); + + core::PlanNodeId sourceId2; + core::PlanNodePtr source2; + auto plan = + exec::test::PlanBuilder(planNodeIdGenerator) + .addNode(addTvfNode("sequence", sequenceArgs(20, 30, 2))) + .capturePlanNodeId(sourceId2) + .capturePlanNode(source2) + .project({"sequential_number AS left_sequence"}) + .nestedLoopJoin( + source1, "sequential_number = left_sequence", {"left_sequence"}) + .planNode(); + + auto expected = + makeRowVector({makeFlatVector({20, 22, 24, 26, 28, 30})}); + + AssertQueryBuilder(plan) + .splits(sourceId1, splitsForTvf(source1)) + .splits(sourceId2, splitsForTvf(source2)) + .assertResults(expected); +} + +} // namespace facebook::velox::exec::test diff --git a/presto-native-execution/presto_cpp/main/tvf/tests/TableFunctionsInvocationTest.cpp b/presto-native-execution/presto_cpp/main/tvf/tests/TableFunctionsInvocationTest.cpp new file mode 100644 index 0000000000000..e2129cc5701b1 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/tvf/tests/TableFunctionsInvocationTest.cpp @@ -0,0 +1,149 @@ +/* +* Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "presto_cpp/main/tvf/tests/PlanBuilder.h" +#include "presto_cpp/main/tvf/exec/TableFunctionTranslator.h" +#include "presto_cpp/main/tvf/functions/TestingTableFunctions.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/parse/TypeResolver.h" + +using namespace facebook::presto::tvf; + +namespace facebook::velox::exec::test { +class TableFunctionInvocationTest : public OperatorTestBase { + protected: + void SetUp() override { + OperatorTestBase::SetUp(); + }; + + public: + TableFunctionInvocationTest() { + registerSimpleTableFunction("simple_table_function"); + registerIdentityFunction("identity_table_function"); + registerRepeatFunction("repeat_table_function"); + parse::registerTypeResolver(); + + Type::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + + // This code is added in PrestoToVeloxQueryPlan. + auto& registry = DeserializationWithContextRegistryForSharedPtr(); + registry.Register( + "TableFunctionProcessorNode", + presto::tvf::TableFunctionProcessorNode::create); + + velox::exec::Operator::registerOperator( + std::make_unique()); + }; + + protected: + std::unordered_map> simpleTableFunctionArgs(const std::string& column) { + std::unordered_map> args; + args.emplace("COLUMN", std::make_shared(VARCHAR(), makeConstant(StringView(column), 1, VARCHAR()))); + return args; + } +}; + +TEST_F(TableFunctionInvocationTest, DISABLED_simple) { + auto plan = PlanBuilder() + .addNode(addTvfNode("simple_table_function", simpleTableFunctionArgs("col"))) + .planNode(); + + auto expected = makeRowVector({}); + AssertQueryBuilder(plan).assertResults(expected); +} + +TEST_F(TableFunctionInvocationTest, identity) { + auto data = makeRowVector({ + makeFlatVector({1,2,3}), + makeFlatVector({10,20,30}) + }); + auto type = asRowType(data->type()); + std::unordered_map> args; + auto input = std::make_shared(type); + args.emplace("INPUT", input); + + auto plan = PlanBuilder() + .values({data}) + .addNode(addTvfNode("identity_table_function", args)) + .planNode(); + + assertQuery(plan, data); +} + +TEST_F(TableFunctionInvocationTest, repeat) { + auto data = makeRowVector({ + makeFlatVector({1,2,3}), + makeFlatVector({10,20,30}) + }); + auto type = asRowType(data->type()); + std::unordered_map> args; + auto input = std::make_shared(type); + args.insert({"INPUT", input}); + std::shared_ptr count = std::make_shared(BIGINT(), makeConstant(static_cast(2), 1, BIGINT())); + args.insert({"COUNT", count}); + + auto plan = PlanBuilder() + .values({data}) + .addNode(addTvfNode("repeat_table_function", args)) + .planNode(); + + auto expected = makeRowVector({ + makeFlatVector({1,2,3,1,2,3}), + makeFlatVector({10,20,30,10,20,30}) + }); + + assertQuery(plan, expected); +} + +TEST_F(TableFunctionInvocationTest, repeatPartitionOrder) { + auto data = makeRowVector({ + makeFlatVector({1,1,2,3}), + makeFlatVector({10,20,20,30}) + }); + auto type = asRowType(data->type()); + std::unordered_map> args; + auto input = std::make_shared(type); + args.insert({"INPUT", input}); + std::shared_ptr count = std::make_shared(BIGINT(), makeConstant(static_cast(2), 1, BIGINT())); + args.insert({"COUNT", count}); + + std::vector partitions = { + std::make_shared(BIGINT(), "c0") + }; + std::vector sorts = { + std::make_shared(BIGINT(), "c1") + }; + std::vector sortOrders = { + {true, false} + }; + + auto plan = PlanBuilder() + .values({data}) + .addNode(addTvfNode("repeat_table_function", args, partitions, sorts, sortOrders)) + .planNode(); + + auto expected = makeRowVector({ + makeFlatVector({1,1,2,3,1,1,2,3}), + makeFlatVector({10,20,20,30,10,20,20,30}) + }); + + assertQuery(plan, expected); +} + +} // namespace facebook::velox::exec::test \ No newline at end of file diff --git a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt index ad15b13f16e7d..cd7521a132bb6 100644 --- a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt @@ -23,6 +23,7 @@ endif() add_library(presto_types PrestoToVeloxQueryPlan.cpp VeloxPlanValidator.cpp PrestoToVeloxSplit.cpp) target_link_libraries( presto_types + presto_tvf_core presto_velox_expr_conversion presto_connectors presto_operators diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp index a2c445d19d5f0..fc1e3944d8213 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp @@ -34,6 +34,8 @@ #include "presto_cpp/main/operators/PartitionAndSerialize.h" #include "presto_cpp/main/operators/ShuffleRead.h" #include "presto_cpp/main/operators/ShuffleWrite.h" +#include "presto_cpp/main/tvf/core/TableFunctionProcessorNode.h" + #include "presto_cpp/main/types/TypeParser.h" #include "velox/exec/TraceUtil.h" @@ -1841,6 +1843,64 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( toVeloxQueryPlan(node->source, tableWriteInfo, taskId)); } +std::shared_ptr +VeloxQueryPlanConverterBase::toVeloxQueryPlan( + const std::shared_ptr& node, + const std::shared_ptr& tableWriteInfo, + const protocol::TaskId& taskId) { + const auto outputType = toRowType(node->properOutputs, typeParser_); + + std::vector requiredColumns; + std::vector sources; + if (node->source) { + const auto sourceNode = + toVeloxQueryPlan(*node->source, tableWriteInfo, taskId); + sources.push_back(sourceNode); + const auto inputType = sourceNode->outputType(); + for (const auto& variables : node->requiredVariables) { + for (const auto& expr : toVeloxExprs(variables)) { + requiredColumns.push_back(exprToChannel(expr.get(), inputType)); + } + } + } + + auto handle = std::dynamic_pointer_cast( + node->handle.functionHandle); + VELOX_CHECK_NOT_NULL(handle, "Invalid table function handle {}", toJsonString(node->handle)); + + auto functionName = handle->functionName; // fully qualified function name + auto tableFunctionHandlePtr = + ISerializable::deserialize( + folly::parseJson(handle->serializedTableFunctionHandle)); + + std::vector partitionKeys; + std::vector sortingKeys; + std::vector sortingOrders; + + if (node->specification) { + if (!node->specification->partitionBy.empty()) { + partitionKeys = toVeloxExprs(node->specification->partitionBy); + } + if (!node->specification->orderingScheme) { + for (const auto& orderby : node->specification->orderingScheme->orderBy) { + sortingKeys.emplace_back(exprConverter_.toVeloxExpr(orderby.variable)); + sortingOrders.push_back(toVeloxSortOrder(orderby.sortOrder)); + } + } + } + + return std::make_shared( + node->id, + std::move(functionName), + tableFunctionHandlePtr, + std::move(partitionKeys), + std::move(sortingKeys), + std::move(sortingOrders), + outputType, + std::move(requiredColumns), + std::move(sources)); +} + core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( const std::shared_ptr& node, const std::shared_ptr& tableWriteInfo, @@ -1966,6 +2026,11 @@ core::PlanNodePtr VeloxQueryPlanConverterBase::toVeloxQueryPlan( // directly, and does not have the intermediate SampleNode. return toVeloxQueryPlan(sampleNode->source, tableWriteInfo, taskId); } + if (auto tableFunctionProcessor = + std::dynamic_pointer_cast( + node)) { + return toVeloxQueryPlan(tableFunctionProcessor, tableWriteInfo, taskId); + } VELOX_UNSUPPORTED("Unknown plan node type {}", node->_type); } @@ -2370,6 +2435,9 @@ void registerPrestoPlanNodeSerDe() { "ShuffleWriteNode", presto::operators::ShuffleWriteNode::create); registry.Register( "BroadcastWriteNode", presto::operators::BroadcastWriteNode::create); + registry.Register( + "TableFunctionProcessorNode", + presto::tvf::TableFunctionProcessorNode::create); } void parseSqlFunctionHandle( diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h index 0bd111e09cbb8..6a03d3d84fa8b 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.h @@ -15,6 +15,7 @@ #include #include #include "presto_cpp/main/operators/ShuffleInterface.h" +#include "presto_cpp/main/tvf/core/TableFunctionProcessorNode.h" #include "presto_cpp/presto_protocol/core/presto_protocol_core.h" #include "velox/core/Expressions.h" #include "velox/core/PlanFragment.h" @@ -195,6 +196,11 @@ class VeloxQueryPlanConverterBase { const std::shared_ptr& tableWriteInfo, const protocol::TaskId& taskId); + std::shared_ptr toVeloxQueryPlan( + const std::shared_ptr& node, + const std::shared_ptr& tableWriteInfo, + const protocol::TaskId& taskId); + std::vector toVeloxExprs( const std::vector& variables); diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp index cb77fdfffa45d..f0987e30f040c 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp @@ -370,10 +370,9 @@ namespace facebook::presto::protocol::hive { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - BucketFunctionType_enum_table[] = - { // NOLINT: cert-err58-cpp - {BucketFunctionType::HIVE_COMPATIBLE, "HIVE_COMPATIBLE"}, - {BucketFunctionType::PRESTO_NATIVE, "PRESTO_NATIVE"}}; + BucketFunctionType_enum_table[] = { // NOLINT: cert-err58-cpp + {BucketFunctionType::HIVE_COMPATIBLE, "HIVE_COMPATIBLE"}, + {BucketFunctionType::PRESTO_NATIVE, "PRESTO_NATIVE"}}; void to_json(json& j, const BucketFunctionType& e) { static_assert( std::is_enum::value, @@ -599,13 +598,12 @@ namespace facebook::presto::protocol::hive { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - HiveCompressionCodec_enum_table[] = - { // NOLINT: cert-err58-cpp - {HiveCompressionCodec::NONE, "NONE"}, - {HiveCompressionCodec::SNAPPY, "SNAPPY"}, - {HiveCompressionCodec::GZIP, "GZIP"}, - {HiveCompressionCodec::LZ4, "LZ4"}, - {HiveCompressionCodec::ZSTD, "ZSTD"}}; + HiveCompressionCodec_enum_table[] = { // NOLINT: cert-err58-cpp + {HiveCompressionCodec::NONE, "NONE"}, + {HiveCompressionCodec::SNAPPY, "SNAPPY"}, + {HiveCompressionCodec::GZIP, "GZIP"}, + {HiveCompressionCodec::LZ4, "LZ4"}, + {HiveCompressionCodec::ZSTD, "ZSTD"}}; void to_json(json& j, const HiveCompressionCodec& e) { static_assert( std::is_enum::value, diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp index 6674b05f847b4..3afd250c0ea1a 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp @@ -25,12 +25,11 @@ namespace facebook::presto::protocol::iceberg { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - ChangelogOperation_enum_table[] = - { // NOLINT: cert-err58-cpp - {ChangelogOperation::INSERT, "INSERT"}, - {ChangelogOperation::DELETE, "DELETE"}, - {ChangelogOperation::UPDATE_BEFORE, "UPDATE_BEFORE"}, - {ChangelogOperation::UPDATE_AFTER, "UPDATE_AFTER"}}; + ChangelogOperation_enum_table[] = { // NOLINT: cert-err58-cpp + {ChangelogOperation::INSERT, "INSERT"}, + {ChangelogOperation::DELETE, "DELETE"}, + {ChangelogOperation::UPDATE_BEFORE, "UPDATE_BEFORE"}, + {ChangelogOperation::UPDATE_AFTER, "UPDATE_AFTER"}}; void to_json(json& j, const ChangelogOperation& e) { static_assert( std::is_enum::value, @@ -878,15 +877,14 @@ namespace facebook::presto::protocol::iceberg { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - PartitionTransformType_enum_table[] = - { // NOLINT: cert-err58-cpp - {PartitionTransformType::IDENTITY, "IDENTITY"}, - {PartitionTransformType::HOUR, "HOUR"}, - {PartitionTransformType::DAY, "DAY"}, - {PartitionTransformType::MONTH, "MONTH"}, - {PartitionTransformType::YEAR, "YEAR"}, - {PartitionTransformType::BUCKET, "BUCKET"}, - {PartitionTransformType::TRUNCATE, "TRUNCATE"}}; + PartitionTransformType_enum_table[] = { // NOLINT: cert-err58-cpp + {PartitionTransformType::IDENTITY, "IDENTITY"}, + {PartitionTransformType::HOUR, "HOUR"}, + {PartitionTransformType::DAY, "DAY"}, + {PartitionTransformType::MONTH, "MONTH"}, + {PartitionTransformType::YEAR, "YEAR"}, + {PartitionTransformType::BUCKET, "BUCKET"}, + {PartitionTransformType::TRUNCATE, "TRUNCATE"}}; void to_json(json& j, const PartitionTransformType& e) { static_assert( std::is_enum::value, diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/ConnectorProtocol.h b/presto-native-execution/presto_cpp/presto_protocol/core/ConnectorProtocol.h index 130944584c90f..11fa34edd2638 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/ConnectorProtocol.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/ConnectorProtocol.h @@ -426,4 +426,16 @@ using SystemConnectorProtocol = ConnectorProtocolTemplate< NotImplemented, NotImplemented>; +using TvfNativeConnectorProtocol = ConnectorProtocolTemplate< + SystemTableHandle, + SystemTableLayoutHandle, + SystemColumnHandle, + NotImplemented, + NotImplemented, + NativeTableFunctionSplit, + SystemPartitioningHandle, + SystemTransactionHandle, + NotImplemented, + NotImplemented>; + } // namespace facebook::presto::protocol diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index 56e35b97348de..334d1f4f11034 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -36,11 +36,10 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - NodeSelectionStrategy_enum_table[] = - { // NOLINT: cert-err58-cpp - {NodeSelectionStrategy::HARD_AFFINITY, "HARD_AFFINITY"}, - {NodeSelectionStrategy::SOFT_AFFINITY, "SOFT_AFFINITY"}, - {NodeSelectionStrategy::NO_PREFERENCE, "NO_PREFERENCE"}}; + NodeSelectionStrategy_enum_table[] = { // NOLINT: cert-err58-cpp + {NodeSelectionStrategy::HARD_AFFINITY, "HARD_AFFINITY"}, + {NodeSelectionStrategy::SOFT_AFFINITY, "SOFT_AFFINITY"}, + {NodeSelectionStrategy::NO_PREFERENCE, "NO_PREFERENCE"}}; void to_json(json& j, const NodeSelectionStrategy& e) { static_assert( std::is_enum::value, @@ -102,6 +101,177 @@ std::string json_map_key(const VariableReferenceExpression& p) { } } // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +void to_json(json& j, const std::shared_ptr& p) { + if (p == nullptr) { + return; + } + String type = p->_type; + + if (type == "descriptor") { + j = *std::static_pointer_cast(p); + return; + } + if (type == "scalar") { + j = *std::static_pointer_cast(p); + return; + } + if (type == "table") { + j = *std::static_pointer_cast(p); + return; + } + + throw TypeError(type + " no abstract type ArgumentSpecification "); +} + +void from_json(const json& j, std::shared_ptr& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error& e) { + throw ParseError( + std::string(e.what()) + + " ArgumentSpecification ArgumentSpecification"); + } + + if (type == "descriptor") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + if (type == "scalar") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + if (type == "table") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + + throw TypeError(type + " no abstract type ArgumentSpecification "); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +void to_json(json& j, const std::shared_ptr& p) { + if (p == nullptr) { + return; + } + String type = p->_type; + + if (type == "described_table") { + j = *std::static_pointer_cast(p); + return; + } + if (type == "generic_table") { + j = *std::static_pointer_cast(p); + return; + } + if (type == "only_pass_through_table") { + j = *std::static_pointer_cast(p); + return; + } + + throw TypeError(type + " no abstract type ReturnTypeSpecification "); +} + +void from_json(const json& j, std::shared_ptr& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error& e) { + throw ParseError( + std::string(e.what()) + + " ReturnTypeSpecification ReturnTypeSpecification"); + } + + if (type == "described_table") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + if (type == "generic_table") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + if (type == "only_pass_through_table") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + + throw TypeError(type + " no abstract type ReturnTypeSpecification "); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const AbstractConnectorTableFunction& p) { + j = json::object(); + to_json_key( + j, + "schema", + p.schema, + "AbstractConnectorTableFunction", + "String", + "schema"); + to_json_key( + j, "name", p.name, "AbstractConnectorTableFunction", "String", "name"); + to_json_key( + j, + "arguments", + p.arguments, + "AbstractConnectorTableFunction", + "List>", + "arguments"); + to_json_key( + j, + "returnTypeSpecification", + p.returnTypeSpecification, + "AbstractConnectorTableFunction", + "ReturnTypeSpecification", + "returnTypeSpecification"); +} + +void from_json(const json& j, AbstractConnectorTableFunction& p) { + from_json_key( + j, + "schema", + p.schema, + "AbstractConnectorTableFunction", + "String", + "schema"); + from_json_key( + j, "name", p.name, "AbstractConnectorTableFunction", "String", "name"); + from_json_key( + j, + "arguments", + p.arguments, + "AbstractConnectorTableFunction", + "List>", + "arguments"); + from_json_key( + j, + "returnTypeSpecification", + p.returnTypeSpecification, + "AbstractConnectorTableFunction", + "ReturnTypeSpecification", + "returnTypeSpecification"); +} +} // namespace facebook::presto::protocol namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { if (p == nullptr) { @@ -561,12 +731,11 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - AggregationNodeStep_enum_table[] = - { // NOLINT: cert-err58-cpp - {AggregationNodeStep::PARTIAL, "PARTIAL"}, - {AggregationNodeStep::FINAL, "FINAL"}, - {AggregationNodeStep::INTERMEDIATE, "INTERMEDIATE"}, - {AggregationNodeStep::SINGLE, "SINGLE"}}; + AggregationNodeStep_enum_table[] = { // NOLINT: cert-err58-cpp + {AggregationNodeStep::PARTIAL, "PARTIAL"}, + {AggregationNodeStep::FINAL, "FINAL"}, + {AggregationNodeStep::INTERMEDIATE, "INTERMEDIATE"}, + {AggregationNodeStep::SINGLE, "SINGLE"}}; void to_json(json& j, const AggregationNodeStep& e) { static_assert( std::is_enum::value, @@ -781,6 +950,15 @@ void to_json(json& j, const std::shared_ptr& p) { j = *std::static_pointer_cast(p); return; } + if (type == "com.facebook.presto.sql.planner.plan.TableFunctionNode") { + j = *std::static_pointer_cast(p); + return; + } + if (type == + "com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode") { + j = *std::static_pointer_cast(p); + return; + } throw TypeError(type + " no abstract type PlanNode "); } @@ -985,6 +1163,21 @@ void from_json(const json& j, std::shared_ptr& p) { p = std::static_pointer_cast(k); return; } + if (type == "com.facebook.presto.sql.planner.plan.TableFunctionNode") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + if (type == + "com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } throw TypeError(type + " no abstract type PlanNode "); } @@ -2847,11 +3040,10 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - BuiltInFunctionKind_enum_table[] = - { // NOLINT: cert-err58-cpp - {BuiltInFunctionKind::ENGINE, "ENGINE"}, - {BuiltInFunctionKind::PLUGIN, "PLUGIN"}, - {BuiltInFunctionKind::WORKER, "WORKER"}}; + BuiltInFunctionKind_enum_table[] = { // NOLINT: cert-err58-cpp + {BuiltInFunctionKind::ENGINE, "ENGINE"}, + {BuiltInFunctionKind::PLUGIN, "PLUGIN"}, + {BuiltInFunctionKind::WORKER, "WORKER"}}; void to_json(json& j, const BuiltInFunctionKind& e) { static_assert( std::is_enum::value, @@ -3391,6 +3583,97 @@ void from_json(const json& j, Column& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +void to_json(json& j, const std::shared_ptr& p) { + if (p == nullptr) { + return; + } + String type = p->_type; + + if (type == "descriptor") { + j = *std::static_pointer_cast(p); + return; + } + if (type == "scalar") { + j = *std::static_pointer_cast(p); + return; + } + if (type == "table") { + j = *std::static_pointer_cast(p); + return; + } + + throw TypeError(type + " no abstract type Argument "); +} + +void from_json(const json& j, std::shared_ptr& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error& e) { + throw ParseError(std::string(e.what()) + " Argument Argument"); + } + + if (type == "descriptor") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + if (type == "scalar") { + std::shared_ptr k = std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + if (type == "table") { + std::shared_ptr k = std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + + throw TypeError(type + " no abstract type Argument "); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const ConnectorTableMetadata& p) { + j = json::object(); + to_json_key( + j, + "functionName", + p.functionName, + "ConnectorTableMetadata", + "QualifiedObjectName", + "functionName"); + to_json_key( + j, + "arguments", + p.arguments, + "ConnectorTableMetadata", + "Map>", + "arguments"); +} + +void from_json(const json& j, ConnectorTableMetadata& p) { + from_json_key( + j, + "functionName", + p.functionName, + "ConnectorTableMetadata", + "QualifiedObjectName", + "functionName"); + from_json_key( + j, + "arguments", + p.arguments, + "ConnectorTableMetadata", + "Map>", + "arguments"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { void to_json(json& j, const Block& p) { j = p.data; @@ -3784,6 +4067,133 @@ void from_json(const json& j, DeleteNode& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { + +void to_json(json& j, const Field& p) { + j = json::object(); + to_json_key(j, "name", p.name, "Field", "String", "name"); + to_json_key(j, "type", p.type, "Field", "Type", "type"); +} + +void from_json(const json& j, Field& p) { + from_json_key(j, "name", p.name, "Field", "String", "name"); + from_json_key(j, "type", p.type, "Field", "Type", "type"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const Descriptor& p) { + j = json::object(); + to_json_key(j, "fields", p.fields, "Descriptor", "List", "fields"); +} + +void from_json(const json& j, Descriptor& p) { + from_json_key(j, "fields", p.fields, "Descriptor", "List", "fields"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +DescribedTableReturnTypeSpecification:: + DescribedTableReturnTypeSpecification() noexcept { + _type = "described_table"; +} + +void to_json(json& j, const DescribedTableReturnTypeSpecification& p) { + j = json::object(); + j["@type"] = "described_table"; + to_json_key( + j, + "descriptor", + p.descriptor, + "DescribedTableReturnTypeSpecification", + "Descriptor", + "descriptor"); +} + +void from_json(const json& j, DescribedTableReturnTypeSpecification& p) { + p._type = j["@type"]; + from_json_key( + j, + "descriptor", + p.descriptor, + "DescribedTableReturnTypeSpecification", + "Descriptor", + "descriptor"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +DescriptorArgument::DescriptorArgument() noexcept { + _type = "descriptor"; +} + +void to_json(json& j, const DescriptorArgument& p) { + j = json::object(); + j["@type"] = "descriptor"; + to_json_key( + j, + "descriptor", + p.descriptor, + "DescriptorArgument", + "Descriptor", + "descriptor"); +} + +void from_json(const json& j, DescriptorArgument& p) { + p._type = j["@type"]; + from_json_key( + j, + "descriptor", + p.descriptor, + "DescriptorArgument", + "Descriptor", + "descriptor"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +DescriptorArgumentSpecification::DescriptorArgumentSpecification() noexcept { + _type = "descriptor"; +} + +void to_json(json& j, const DescriptorArgumentSpecification& p) { + j = json::object(); + j["@type"] = "descriptor"; + to_json_key( + j, "name", p.name, "DescriptorArgumentSpecification", "String", "name"); + to_json_key( + j, + "required", + p.required, + "DescriptorArgumentSpecification", + "bool", + "required"); + to_json_key( + j, + "defaultValue", + p.defaultValue, + "DescriptorArgumentSpecification", + "Descriptor", + "defaultValue"); +} + +void from_json(const json& j, DescriptorArgumentSpecification& p) { + p._type = j["@type"]; + from_json_key( + j, "name", p.name, "DescriptorArgumentSpecification", "String", "name"); + from_json_key( + j, + "required", + p.required, + "DescriptorArgumentSpecification", + "bool", + "required"); + from_json_key( + j, + "defaultValue", + p.defaultValue, + "DescriptorArgumentSpecification", + "Descriptor", + "defaultValue"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { DistinctLimitNode::DistinctLimitNode() noexcept { _type = ".DistinctLimitNode"; } @@ -5837,25 +6247,54 @@ void from_json(const json& j, Function& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -GroupIdNode::GroupIdNode() noexcept { - _type = "com.facebook.presto.sql.planner.plan.GroupIdNode"; +GenericTableReturnTypeSpecification:: + GenericTableReturnTypeSpecification() noexcept { + _type = "generic_table"; } -void to_json(json& j, const GroupIdNode& p) { +void to_json(json& j, const GenericTableReturnTypeSpecification& p) { j = json::object(); - j["@type"] = "com.facebook.presto.sql.planner.plan.GroupIdNode"; - to_json_key(j, "id", p.id, "GroupIdNode", "PlanNodeId", "id"); - to_json_key(j, "source", p.source, "GroupIdNode", "PlanNode", "source"); + j["@type"] = "generic_table"; to_json_key( j, - "groupingSets", - p.groupingSets, - "GroupIdNode", - "List>", - "groupingSets"); - to_json_key( - j, - "groupingColumns", + "returnType", + p.returnType, + "GenericTableReturnTypeSpecification", + "String", + "returnType"); +} + +void from_json(const json& j, GenericTableReturnTypeSpecification& p) { + p._type = j["@type"]; + from_json_key( + j, + "returnType", + p.returnType, + "GenericTableReturnTypeSpecification", + "String", + "returnType"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +GroupIdNode::GroupIdNode() noexcept { + _type = "com.facebook.presto.sql.planner.plan.GroupIdNode"; +} + +void to_json(json& j, const GroupIdNode& p) { + j = json::object(); + j["@type"] = "com.facebook.presto.sql.planner.plan.GroupIdNode"; + to_json_key(j, "id", p.id, "GroupIdNode", "PlanNodeId", "id"); + to_json_key(j, "source", p.source, "GroupIdNode", "PlanNode", "source"); + to_json_key( + j, + "groupingSets", + p.groupingSets, + "GroupIdNode", + "List>", + "groupingSets"); + to_json_key( + j, + "groupingColumns", p.groupingColumns, "GroupIdNode", "Map", @@ -6447,10 +6886,9 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - JoinDistributionType_enum_table[] = - { // NOLINT: cert-err58-cpp - {JoinDistributionType::PARTITIONED, "PARTITIONED"}, - {JoinDistributionType::REPLICATED, "REPLICATED"}}; + JoinDistributionType_enum_table[] = { // NOLINT: cert-err58-cpp + {JoinDistributionType::PARTITIONED, "PARTITIONED"}, + {JoinDistributionType::REPLICATED, "REPLICATED"}}; void to_json(json& j, const JoinDistributionType& e) { static_assert( std::is_enum::value, @@ -6669,6 +7107,57 @@ void from_json(const json& j, JoinNodeStatsEstimate& p) { } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +void to_json(json& j, const JsonBasedTableFunctionMetadata& p) { + j = json::object(); + to_json_key( + j, + "functionName", + p.functionName, + "JsonBasedTableFunctionMetadata", + "QualifiedObjectName", + "functionName"); + to_json_key( + j, + "arguments", + p.arguments, + "JsonBasedTableFunctionMetadata", + "List>", + "arguments"); + to_json_key( + j, + "returnTypeSpecification", + p.returnTypeSpecification, + "JsonBasedTableFunctionMetadata", + "ReturnTypeSpecification", + "returnTypeSpecification"); +} + +void from_json(const json& j, JsonBasedTableFunctionMetadata& p) { + from_json_key( + j, + "functionName", + p.functionName, + "JsonBasedTableFunctionMetadata", + "QualifiedObjectName", + "functionName"); + from_json_key( + j, + "arguments", + p.arguments, + "JsonBasedTableFunctionMetadata", + "List>", + "arguments"); + from_json_key( + j, + "returnTypeSpecification", + p.returnTypeSpecification, + "JsonBasedTableFunctionMetadata", + "ReturnTypeSpecification", + "returnTypeSpecification"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + void to_json(json& j, const JsonBasedUdfFunctionMetadata& p) { j = json::object(); to_json_key( @@ -7559,6 +8048,44 @@ void from_json(const json& j, MergeTarget& p) { } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +void to_json(json& j, const NativeField& p) { + j = json::object(); + to_json_key(j, "name", p.name, "NativeField", "String", "name"); + to_json_key( + j, + "typeSignature", + p.typeSignature, + "NativeField", + "TypeSignature", + "typeSignature"); +} + +void from_json(const json& j, NativeField& p) { + from_json_key(j, "name", p.name, "NativeField", "String", "name"); + from_json_key( + j, + "typeSignature", + p.typeSignature, + "NativeField", + "TypeSignature", + "typeSignature"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const NativeDescriptor& p) { + j = json::object(); + to_json_key( + j, "fields", p.fields, "NativeDescriptor", "List", "fields"); +} + +void from_json(const json& j, NativeDescriptor& p) { + from_json_key( + j, "fields", p.fields, "NativeDescriptor", "List", "fields"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + void to_json(json& j, const NativeSidecarFailureInfo& p) { j = json::object(); to_json_key(j, "type", p.type, "NativeSidecarFailureInfo", "String", "type"); @@ -7620,6 +8147,127 @@ void from_json(const json& j, NativeSidecarFailureInfo& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +NativeTableFunctionHandle::NativeTableFunctionHandle() noexcept { + _type = "system:com.facebook.presto.tvf.NativeTableFunctionHandle"; +} + +void to_json(json& j, const NativeTableFunctionHandle& p) { + j = json::object(); + j["@type"] = "system:com.facebook.presto.tvf.NativeTableFunctionHandle"; + to_json_key( + j, + "serializedTableFunctionHandle", + p.serializedTableFunctionHandle, + "NativeTableFunctionHandle", + "String", + "serializedTableFunctionHandle"); + to_json_key( + j, + "functionName", + p.functionName, + "NativeTableFunctionHandle", + "QualifiedObjectName", + "functionName"); +} + +void from_json(const json& j, NativeTableFunctionHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "serializedTableFunctionHandle", + p.serializedTableFunctionHandle, + "NativeTableFunctionHandle", + "String", + "serializedTableFunctionHandle"); + from_json_key( + j, + "functionName", + p.functionName, + "NativeTableFunctionHandle", + "QualifiedObjectName", + "functionName"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const NativeTableFunctionAnalysis& p) { + j = json::object(); + to_json_key( + j, + "returnedType", + p.returnedType, + "NativeTableFunctionAnalysis", + "NativeDescriptor", + "returnedType"); + to_json_key( + j, + "requiredColumns", + p.requiredColumns, + "NativeTableFunctionAnalysis", + "Map>", + "requiredColumns"); + to_json_key( + j, + "handle", + p.handle, + "NativeTableFunctionAnalysis", + "NativeTableFunctionHandle", + "handle"); +} + +void from_json(const json& j, NativeTableFunctionAnalysis& p) { + from_json_key( + j, + "returnedType", + p.returnedType, + "NativeTableFunctionAnalysis", + "NativeDescriptor", + "returnedType"); + from_json_key( + j, + "requiredColumns", + p.requiredColumns, + "NativeTableFunctionAnalysis", + "Map>", + "requiredColumns"); + from_json_key( + j, + "handle", + p.handle, + "NativeTableFunctionAnalysis", + "NativeTableFunctionHandle", + "handle"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +NativeTableFunctionSplit::NativeTableFunctionSplit() noexcept { + _type = "system:com.facebook.presto.tvf.NativeTableFunctionSplit"; +} + +void to_json(json& j, const NativeTableFunctionSplit& p) { + j = json::object(); + j["@type"] = "system:com.facebook.presto.tvf.NativeTableFunctionSplit"; + to_json_key( + j, + "serializedTableFunctionSplitHandle", + p.serializedTableFunctionSplitHandle, + "NativeTableFunctionSplit", + "String", + "serializedTableFunctionSplitHandle"); +} + +void from_json(const json& j, NativeTableFunctionSplit& p) { + p._type = j["@type"]; + from_json_key( + j, + "serializedTableFunctionSplitHandle", + p.serializedTableFunctionSplitHandle, + "NativeTableFunctionSplit", + "String", + "serializedTableFunctionSplitHandle"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { void to_json(json& j, const NodeLoadMetrics& p) { j = json::object(); @@ -7891,6 +8539,71 @@ void from_json(const json& j, NodeStatus& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { + +void to_json(json& j, const Serializable& p) { + j = json::object(); + to_json_key(j, "type", p.type, "Serializable", "Type", "type"); + to_json_key(j, "block", p.block, "Serializable", "Block", "block"); +} + +void from_json(const json& j, Serializable& p) { + from_json_key(j, "type", p.type, "Serializable", "Type", "type"); + from_json_key(j, "block", p.block, "Serializable", "Block", "block"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const NullableValue& p) { + j = json::object(); + to_json_key( + j, + "serializable", + p.serializable, + "NullableValue", + "Serializable", + "serializable"); +} + +void from_json(const json& j, NullableValue& p) { + from_json_key( + j, + "serializable", + p.serializable, + "NullableValue", + "Serializable", + "serializable"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +OnlyPassThroughReturnTypeSpecification:: + OnlyPassThroughReturnTypeSpecification() noexcept { + _type = "only_pass_through_table"; +} + +void to_json(json& j, const OnlyPassThroughReturnTypeSpecification& p) { + j = json::object(); + j["@type"] = "only_pass_through_table"; + to_json_key( + j, + "returnType", + p.returnType, + "OnlyPassThroughReturnTypeSpecification", + "String", + "returnType"); +} + +void from_json(const json& j, OnlyPassThroughReturnTypeSpecification& p) { + p._type = j["@type"]; + from_json_key( + j, + "returnType", + p.returnType, + "OnlyPassThroughReturnTypeSpecification", + "String", + "returnType"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { // Loosely copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays @@ -8156,6 +8869,80 @@ void from_json(const json& j, PartialAggregationStatsEstimate& p) { } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +void to_json(json& j, const PassThroughColumn& p) { + j = json::object(); + to_json_key( + j, + "outputVariables", + p.outputVariables, + "PassThroughColumn", + "VariableReferenceExpression", + "outputVariables"); + to_json_key( + j, + "partitioningColumn", + p.partitioningColumn, + "PassThroughColumn", + "bool", + "partitioningColumn"); +} + +void from_json(const json& j, PassThroughColumn& p) { + from_json_key( + j, + "outputVariables", + p.outputVariables, + "PassThroughColumn", + "VariableReferenceExpression", + "outputVariables"); + from_json_key( + j, + "partitioningColumn", + p.partitioningColumn, + "PassThroughColumn", + "bool", + "partitioningColumn"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const PassThroughSpecification& p) { + j = json::object(); + to_json_key( + j, + "declaredAsPassThrough", + p.declaredAsPassThrough, + "PassThroughSpecification", + "bool", + "declaredAsPassThrough"); + to_json_key( + j, + "columns", + p.columns, + "PassThroughSpecification", + "List", + "columns"); +} + +void from_json(const json& j, PassThroughSpecification& p) { + from_json_key( + j, + "declaredAsPassThrough", + p.declaredAsPassThrough, + "PassThroughSpecification", + "bool", + "declaredAsPassThrough"); + from_json_key( + j, + "columns", + p.columns, + "PassThroughSpecification", + "List", + "columns"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + void to_json(json& j, const PipelineStats& p) { j = json::object(); to_json_key( @@ -8705,17 +9492,14 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - StageExecutionStrategy_enum_table[] = - { // NOLINT: cert-err58-cpp - {StageExecutionStrategy::UNGROUPED_EXECUTION, - "UNGROUPED_EXECUTION"}, - {StageExecutionStrategy::FIXED_LIFESPAN_SCHEDULE_GROUPED_EXECUTION, - "FIXED_LIFESPAN_SCHEDULE_GROUPED_EXECUTION"}, - {StageExecutionStrategy:: - DYNAMIC_LIFESPAN_SCHEDULE_GROUPED_EXECUTION, - "DYNAMIC_LIFESPAN_SCHEDULE_GROUPED_EXECUTION"}, - {StageExecutionStrategy::RECOVERABLE_GROUPED_EXECUTION, - "RECOVERABLE_GROUPED_EXECUTION"}}; + StageExecutionStrategy_enum_table[] = { // NOLINT: cert-err58-cpp + {StageExecutionStrategy::UNGROUPED_EXECUTION, "UNGROUPED_EXECUTION"}, + {StageExecutionStrategy::FIXED_LIFESPAN_SCHEDULE_GROUPED_EXECUTION, + "FIXED_LIFESPAN_SCHEDULE_GROUPED_EXECUTION"}, + {StageExecutionStrategy::DYNAMIC_LIFESPAN_SCHEDULE_GROUPED_EXECUTION, + "DYNAMIC_LIFESPAN_SCHEDULE_GROUPED_EXECUTION"}, + {StageExecutionStrategy::RECOVERABLE_GROUPED_EXECUTION, + "RECOVERABLE_GROUPED_EXECUTION"}}; void to_json(json& j, const StageExecutionStrategy& e) { static_assert( std::is_enum::value, @@ -9639,25 +10423,48 @@ void from_json(const json& j, RowNumberNode& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { -// Loosely copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() -// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays -static const std::pair RuntimeUnit_enum_table[] = - { // NOLINT: cert-err58-cpp - {RuntimeUnit::NONE, "NONE"}, - {RuntimeUnit::NANO, "NANO"}, - {RuntimeUnit::BYTE, "BYTE"}}; -void to_json(json& j, const RuntimeUnit& e) { - static_assert( - std::is_enum::value, "RuntimeUnit must be an enum!"); - const auto* it = std::find_if( - std::begin(RuntimeUnit_enum_table), - std::end(RuntimeUnit_enum_table), - [e](const std::pair& ej_pair) -> bool { - return ej_pair.first == e; - }); - j = ((it != std::end(RuntimeUnit_enum_table)) - ? it +void to_json(json& j, const RowType& p) { + j = json::object(); + to_json_key( + j, + "typeSignature", + p.typeSignature, + "RowType", + "TypeSignature", + "typeSignature"); +} + +void from_json(const json& j, RowType& p) { + from_json_key( + j, + "typeSignature", + p.typeSignature, + "RowType", + "TypeSignature", + "typeSignature"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +// Loosely copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() + +// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays +static const std::pair RuntimeUnit_enum_table[] = + { // NOLINT: cert-err58-cpp + {RuntimeUnit::NONE, "NONE"}, + {RuntimeUnit::NANO, "NANO"}, + {RuntimeUnit::BYTE, "BYTE"}}; +void to_json(json& j, const RuntimeUnit& e) { + static_assert( + std::is_enum::value, "RuntimeUnit must be an enum!"); + const auto* it = std::find_if( + std::begin(RuntimeUnit_enum_table), + std::end(RuntimeUnit_enum_table), + [e](const std::pair& ej_pair) -> bool { + return ej_pair.first == e; + }); + j = ((it != std::end(RuntimeUnit_enum_table)) + ? it : std::begin(RuntimeUnit_enum_table)) ->second; } @@ -9771,6 +10578,69 @@ void from_json(const json& j, SampleNode& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +ScalarArgument::ScalarArgument() noexcept { + _type = "scalar"; +} + +void to_json(json& j, const ScalarArgument& p) { + j = json::object(); + j["@type"] = "scalar"; + to_json_key( + j, + "nullableValue", + p.nullableValue, + "ScalarArgument", + "NullableValue", + "nullableValue"); +} + +void from_json(const json& j, ScalarArgument& p) { + p._type = j["@type"]; + from_json_key( + j, + "nullableValue", + p.nullableValue, + "ScalarArgument", + "NullableValue", + "nullableValue"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +ScalarArgumentSpecification::ScalarArgumentSpecification() noexcept { + _type = "scalar"; +} + +void to_json(json& j, const ScalarArgumentSpecification& p) { + j = json::object(); + j["@type"] = "scalar"; + to_json_key( + j, "name", p.name, "ScalarArgumentSpecification", "String", "name"); + to_json_key(j, "type", p.type, "ScalarArgumentSpecification", "Type", "type"); + to_json_key( + j, + "required", + p.required, + "ScalarArgumentSpecification", + "bool", + "required"); +} + +void from_json(const json& j, ScalarArgumentSpecification& p) { + p._type = j["@type"]; + from_json_key( + j, "name", p.name, "ScalarArgumentSpecification", "String", "name"); + from_json_key( + j, "type", p.type, "ScalarArgumentSpecification", "Type", "type"); + from_json_key( + j, + "required", + p.required, + "ScalarArgumentSpecification", + "bool", + "required"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { // Loosely copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays @@ -10478,13 +11348,12 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - SystemPartitionFunction_enum_table[] = - { // NOLINT: cert-err58-cpp - {SystemPartitionFunction::SINGLE, "SINGLE"}, - {SystemPartitionFunction::HASH, "HASH"}, - {SystemPartitionFunction::ROUND_ROBIN, "ROUND_ROBIN"}, - {SystemPartitionFunction::BROADCAST, "BROADCAST"}, - {SystemPartitionFunction::UNKNOWN, "UNKNOWN"}}; + SystemPartitionFunction_enum_table[] = { // NOLINT: cert-err58-cpp + {SystemPartitionFunction::SINGLE, "SINGLE"}, + {SystemPartitionFunction::HASH, "HASH"}, + {SystemPartitionFunction::ROUND_ROBIN, "ROUND_ROBIN"}, + {SystemPartitionFunction::BROADCAST, "BROADCAST"}, + {SystemPartitionFunction::UNKNOWN, "UNKNOWN"}}; void to_json(json& j, const SystemPartitionFunction& e) { static_assert( std::is_enum::value, @@ -10521,14 +11390,13 @@ namespace facebook::presto::protocol { // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays static const std::pair - SystemPartitioning_enum_table[] = - { // NOLINT: cert-err58-cpp - {SystemPartitioning::SINGLE, "SINGLE"}, - {SystemPartitioning::FIXED, "FIXED"}, - {SystemPartitioning::SOURCE, "SOURCE"}, - {SystemPartitioning::SCALED, "SCALED"}, - {SystemPartitioning::COORDINATOR_ONLY, "COORDINATOR_ONLY"}, - {SystemPartitioning::ARBITRARY, "ARBITRARY"}}; + SystemPartitioning_enum_table[] = { // NOLINT: cert-err58-cpp + {SystemPartitioning::SINGLE, "SINGLE"}, + {SystemPartitioning::FIXED, "FIXED"}, + {SystemPartitioning::SOURCE, "SOURCE"}, + {SystemPartitioning::SCALED, "SCALED"}, + {SystemPartitioning::COORDINATOR_ONLY, "COORDINATOR_ONLY"}, + {SystemPartitioning::ARBITRARY, "ARBITRARY"}}; void to_json(json& j, const SystemPartitioning& e) { static_assert( std::is_enum::value, @@ -10831,6 +11699,563 @@ void from_json(const json& j, SystemTransactionHandle& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +TableArgument::TableArgument() noexcept { + _type = "table"; +} + +void to_json(json& j, const TableArgument& p) { + j = json::object(); + j["@type"] = "table"; + to_json_key(j, "rowType", p.rowType, "TableArgument", "RowType", "rowType"); + to_json_key( + j, + "partitionBy", + p.partitionBy, + "TableArgument", + "List", + "partitionBy"); + to_json_key( + j, "orderBy", p.orderBy, "TableArgument", "List", "orderBy"); + to_json_key(j, "fields", p.fields, "TableArgument", "List", "fields"); +} + +void from_json(const json& j, TableArgument& p) { + p._type = j["@type"]; + from_json_key(j, "rowType", p.rowType, "TableArgument", "RowType", "rowType"); + from_json_key( + j, + "partitionBy", + p.partitionBy, + "TableArgument", + "List", + "partitionBy"); + from_json_key( + j, "orderBy", p.orderBy, "TableArgument", "List", "orderBy"); + from_json_key( + j, "fields", p.fields, "TableArgument", "List", "fields"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const TableArgumentProperties& p) { + j = json::object(); + to_json_key( + j, + "argumentName", + p.argumentName, + "TableArgumentProperties", + "String", + "argumentName"); + to_json_key( + j, + "rowSemantics", + p.rowSemantics, + "TableArgumentProperties", + "bool", + "rowSemantics"); + to_json_key( + j, + "pruneWhenEmpty", + p.pruneWhenEmpty, + "TableArgumentProperties", + "bool", + "pruneWhenEmpty"); + to_json_key( + j, + "passThroughSpecification", + p.passThroughSpecification, + "TableArgumentProperties", + "PassThroughSpecification", + "passThroughSpecification"); + to_json_key( + j, + "requiredColumns", + p.requiredColumns, + "TableArgumentProperties", + "List", + "requiredColumns"); + to_json_key( + j, + "specification", + p.specification, + "TableArgumentProperties", + "DataOrganizationSpecification", + "specification"); +} + +void from_json(const json& j, TableArgumentProperties& p) { + from_json_key( + j, + "argumentName", + p.argumentName, + "TableArgumentProperties", + "String", + "argumentName"); + from_json_key( + j, + "rowSemantics", + p.rowSemantics, + "TableArgumentProperties", + "bool", + "rowSemantics"); + from_json_key( + j, + "pruneWhenEmpty", + p.pruneWhenEmpty, + "TableArgumentProperties", + "bool", + "pruneWhenEmpty"); + from_json_key( + j, + "passThroughSpecification", + p.passThroughSpecification, + "TableArgumentProperties", + "PassThroughSpecification", + "passThroughSpecification"); + from_json_key( + j, + "requiredColumns", + p.requiredColumns, + "TableArgumentProperties", + "List", + "requiredColumns"); + from_json_key( + j, + "specification", + p.specification, + "TableArgumentProperties", + "DataOrganizationSpecification", + "specification"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +TableArgumentSpecification::TableArgumentSpecification() noexcept { + _type = "table"; +} + +void to_json(json& j, const TableArgumentSpecification& p) { + j = json::object(); + j["@type"] = "table"; + to_json_key( + j, "name", p.name, "TableArgumentSpecification", "String", "name"); + to_json_key( + j, + "rowSemantics", + p.rowSemantics, + "TableArgumentSpecification", + "bool", + "rowSemantics"); + to_json_key( + j, + "pruneWhenEmpty", + p.pruneWhenEmpty, + "TableArgumentSpecification", + "bool", + "pruneWhenEmpty"); + to_json_key( + j, + "passThroughColumns", + p.passThroughColumns, + "TableArgumentSpecification", + "bool", + "passThroughColumns"); +} + +void from_json(const json& j, TableArgumentSpecification& p) { + p._type = j["@type"]; + from_json_key( + j, "name", p.name, "TableArgumentSpecification", "String", "name"); + from_json_key( + j, + "rowSemantics", + p.rowSemantics, + "TableArgumentSpecification", + "bool", + "rowSemantics"); + from_json_key( + j, + "pruneWhenEmpty", + p.pruneWhenEmpty, + "TableArgumentSpecification", + "bool", + "pruneWhenEmpty"); + from_json_key( + j, + "passThroughColumns", + p.passThroughColumns, + "TableArgumentSpecification", + "bool", + "passThroughColumns"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +void to_json(json& j, const std::shared_ptr& p) { + if (p == nullptr) { + return; + } + String type = p->_type; + + if (type == "system:com.facebook.presto.tvf.NativeTableFunctionHandle") { + j = *std::static_pointer_cast(p); + return; + } + + throw TypeError(type + " no abstract type ConnectorTableFunctionHandle "); +} + +void from_json( + const json& j, + std::shared_ptr& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error& e) { + throw ParseError( + std::string(e.what()) + + " ConnectorTableFunctionHandle ConnectorTableFunctionHandle"); + } + + if (type == "system:com.facebook.presto.tvf.NativeTableFunctionHandle") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + + throw TypeError(type + " no abstract type ConnectorTableFunctionHandle "); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { + +void to_json(json& j, const TableFunctionHandle& p) { + j = json::object(); + to_json_key( + j, + "connectorId", + p.connectorId, + "TableFunctionHandle", + "ConnectorId", + "connectorId"); + to_json_key( + j, + "functionHandle", + p.functionHandle, + "TableFunctionHandle", + "ConnectorTableFunctionHandle", + "functionHandle"); + to_json_key( + j, + "transactionHandle", + p.transactionHandle, + "TableFunctionHandle", + "ConnectorTransactionHandle", + "transactionHandle"); +} + +void from_json(const json& j, TableFunctionHandle& p) { + from_json_key( + j, + "connectorId", + p.connectorId, + "TableFunctionHandle", + "ConnectorId", + "connectorId"); + from_json_key( + j, + "functionHandle", + p.functionHandle, + "TableFunctionHandle", + "ConnectorTableFunctionHandle", + "functionHandle"); + from_json_key( + j, + "transactionHandle", + p.transactionHandle, + "TableFunctionHandle", + "ConnectorTransactionHandle", + "transactionHandle"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +TableFunctionNode::TableFunctionNode() noexcept { + _type = "com.facebook.presto.sql.planner.plan.TableFunctionNode"; +} + +void to_json(json& j, const TableFunctionNode& p) { + j = json::object(); + j["@type"] = "com.facebook.presto.sql.planner.plan.TableFunctionNode"; + to_json_key(j, "id", p.id, "TableFunctionNode", "PlanNodeId", "id"); + to_json_key(j, "name", p.name, "TableFunctionNode", "String", "name"); + to_json_key( + j, + "arguments", + p.arguments, + "TableFunctionNode", + "Map>", + "arguments"); + to_json_key( + j, + "outputVariables", + p.outputVariables, + "TableFunctionNode", + "List", + "outputVariables"); + to_json_key( + j, + "sources", + p.sources, + "TableFunctionNode", + "List>", + "sources"); + to_json_key( + j, + "tableArgumentProperties", + p.tableArgumentProperties, + "TableFunctionNode", + "List", + "tableArgumentProperties"); + to_json_key( + j, + "copartitioningLists", + p.copartitioningLists, + "TableFunctionNode", + "List>", + "copartitioningLists"); + to_json_key( + j, + "handle", + p.handle, + "TableFunctionNode", + "TableFunctionHandle", + "handle"); +} + +void from_json(const json& j, TableFunctionNode& p) { + p._type = j["@type"]; + from_json_key(j, "id", p.id, "TableFunctionNode", "PlanNodeId", "id"); + from_json_key(j, "name", p.name, "TableFunctionNode", "String", "name"); + from_json_key( + j, + "arguments", + p.arguments, + "TableFunctionNode", + "Map>", + "arguments"); + from_json_key( + j, + "outputVariables", + p.outputVariables, + "TableFunctionNode", + "List", + "outputVariables"); + from_json_key( + j, + "sources", + p.sources, + "TableFunctionNode", + "List>", + "sources"); + from_json_key( + j, + "tableArgumentProperties", + p.tableArgumentProperties, + "TableFunctionNode", + "List", + "tableArgumentProperties"); + from_json_key( + j, + "copartitioningLists", + p.copartitioningLists, + "TableFunctionNode", + "List>", + "copartitioningLists"); + from_json_key( + j, + "handle", + p.handle, + "TableFunctionNode", + "TableFunctionHandle", + "handle"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +TableFunctionProcessorNode::TableFunctionProcessorNode() noexcept { + _type = "com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode"; +} + +void to_json(json& j, const TableFunctionProcessorNode& p) { + j = json::object(); + j["@type"] = + "com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode"; + to_json_key(j, "id", p.id, "TableFunctionProcessorNode", "PlanNodeId", "id"); + to_json_key( + j, "name", p.name, "TableFunctionProcessorNode", "String", "name"); + to_json_key( + j, + "properOutputs", + p.properOutputs, + "TableFunctionProcessorNode", + "List", + "properOutputs"); + to_json_key( + j, + "source", + p.source, + "TableFunctionProcessorNode", + "std::shared_ptr", + "source"); + to_json_key( + j, + "pruneWhenEmpty", + p.pruneWhenEmpty, + "TableFunctionProcessorNode", + "bool", + "pruneWhenEmpty"); + to_json_key( + j, + "passThroughSpecifications", + p.passThroughSpecifications, + "TableFunctionProcessorNode", + "List", + "passThroughSpecifications"); + to_json_key( + j, + "requiredVariables", + p.requiredVariables, + "TableFunctionProcessorNode", + "List>", + "requiredVariables"); + to_json_key( + j, + "markerVariables", + p.markerVariables, + "TableFunctionProcessorNode", + "Map", + "markerVariables"); + to_json_key( + j, + "specification", + p.specification, + "TableFunctionProcessorNode", + "DataOrganizationSpecification", + "specification"); + to_json_key( + j, + "prePartitioned", + p.prePartitioned, + "TableFunctionProcessorNode", + "List", + "prePartitioned"); + to_json_key( + j, + "preSorted", + p.preSorted, + "TableFunctionProcessorNode", + "int", + "preSorted"); + to_json_key( + j, + "hashSymbol", + p.hashSymbol, + "TableFunctionProcessorNode", + "VariableReferenceExpression", + "hashSymbol"); + to_json_key( + j, + "handle", + p.handle, + "TableFunctionProcessorNode", + "TableFunctionHandle", + "handle"); +} + +void from_json(const json& j, TableFunctionProcessorNode& p) { + p._type = j["@type"]; + from_json_key( + j, "id", p.id, "TableFunctionProcessorNode", "PlanNodeId", "id"); + from_json_key( + j, "name", p.name, "TableFunctionProcessorNode", "String", "name"); + from_json_key( + j, + "properOutputs", + p.properOutputs, + "TableFunctionProcessorNode", + "List", + "properOutputs"); + from_json_key( + j, + "source", + p.source, + "TableFunctionProcessorNode", + "std::shared_ptr", + "source"); + from_json_key( + j, + "pruneWhenEmpty", + p.pruneWhenEmpty, + "TableFunctionProcessorNode", + "bool", + "pruneWhenEmpty"); + from_json_key( + j, + "passThroughSpecifications", + p.passThroughSpecifications, + "TableFunctionProcessorNode", + "List", + "passThroughSpecifications"); + from_json_key( + j, + "requiredVariables", + p.requiredVariables, + "TableFunctionProcessorNode", + "List>", + "requiredVariables"); + from_json_key( + j, + "markerVariables", + p.markerVariables, + "TableFunctionProcessorNode", + "Map", + "markerVariables"); + from_json_key( + j, + "specification", + p.specification, + "TableFunctionProcessorNode", + "DataOrganizationSpecification", + "specification"); + from_json_key( + j, + "prePartitioned", + p.prePartitioned, + "TableFunctionProcessorNode", + "List", + "prePartitioned"); + from_json_key( + j, + "preSorted", + p.preSorted, + "TableFunctionProcessorNode", + "int", + "preSorted"); + from_json_key( + j, + "hashSymbol", + p.hashSymbol, + "TableFunctionProcessorNode", + "VariableReferenceExpression", + "hashSymbol"); + from_json_key( + j, + "handle", + p.handle, + "TableFunctionProcessorNode", + "TableFunctionHandle", + "handle"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { TableScanNode::TableScanNode() noexcept { _type = ".TableScanNode"; } diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h index a189d8bfa18b9..7630d171d2b43 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h @@ -70,21 +70,21 @@ extern const char* const PRESTO_ABORT_TASK_URL_PARAM; class Exception : public std::runtime_error { public: explicit Exception(const std::string& message) - : std::runtime_error(message) {}; + : std::runtime_error(message){}; }; class TypeError : public Exception { public: - explicit TypeError(const std::string& message) : Exception(message) {}; + explicit TypeError(const std::string& message) : Exception(message){}; }; class OutOfRange : public Exception { public: - explicit OutOfRange(const std::string& message) : Exception(message) {}; + explicit OutOfRange(const std::string& message) : Exception(message){}; }; class ParseError : public Exception { public: - explicit ParseError(const std::string& message) : Exception(message) {}; + explicit ParseError(const std::string& message) : Exception(message){}; }; using String = std::string; @@ -268,6 +268,16 @@ struct adl_serializer> { // Forward declaration of all abstract types // namespace facebook::presto::protocol { +struct ArgumentSpecification : public JsonEncodedSubclass {}; +void to_json(json& j, const std::shared_ptr& p); +void from_json(const json& j, std::shared_ptr& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct ReturnTypeSpecification : public JsonEncodedSubclass {}; +void to_json(json& j, const std::shared_ptr& p); +void from_json(const json& j, std::shared_ptr& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct FunctionHandle : public JsonEncodedSubclass {}; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); @@ -297,6 +307,11 @@ void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct Argument : public JsonEncodedSubclass {}; +void to_json(json& j, const std::shared_ptr& p); +void from_json(const json& j, std::shared_ptr& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct InputDistribution : public JsonEncodedSubclass {}; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); @@ -325,7 +340,22 @@ struct ConnectorMergeTableHandle : public JsonEncodedSubclass {}; void to_json(json& j, const std::shared_ptr& p); void from_json(const json& j, std::shared_ptr& p); } // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct ConnectorTableFunctionHandle : public JsonEncodedSubclass {}; +void to_json(json& j, const std::shared_ptr& p); +void from_json(const json& j, std::shared_ptr& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct AbstractConnectorTableFunction { + String schema = {}; + String name = {}; + List> arguments = {}; + std::shared_ptr returnTypeSpecification = {}; +}; +void to_json(json& j, const AbstractConnectorTableFunction& p); +void from_json(const json& j, AbstractConnectorTableFunction& p); +} // namespace facebook::presto::protocol namespace facebook::presto::protocol { struct SourceLocation { int line = {}; @@ -928,6 +958,14 @@ struct Column { void to_json(json& j, const Column& p); void from_json(const json& j, Column& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct ConnectorTableMetadata { + QualifiedObjectName functionName = {}; + Map> arguments = {}; +}; +void to_json(json& j, const ConnectorTableMetadata& p); +void from_json(const json& j, ConnectorTableMetadata& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { @@ -1045,6 +1083,50 @@ void to_json(json& j, const DeleteNode& p); void from_json(const json& j, DeleteNode& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct Field { + std::shared_ptr name = {}; + std::shared_ptr type = {}; +}; +void to_json(json& j, const Field& p); +void from_json(const json& j, Field& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct Descriptor { + List fields = {}; +}; +void to_json(json& j, const Descriptor& p); +void from_json(const json& j, Descriptor& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct DescribedTableReturnTypeSpecification : public ReturnTypeSpecification { + Descriptor descriptor = {}; + + DescribedTableReturnTypeSpecification() noexcept; +}; +void to_json(json& j, const DescribedTableReturnTypeSpecification& p); +void from_json(const json& j, DescribedTableReturnTypeSpecification& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct DescriptorArgument : public Argument { + std::shared_ptr descriptor = {}; + + DescriptorArgument() noexcept; +}; +void to_json(json& j, const DescriptorArgument& p); +void from_json(const json& j, DescriptorArgument& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct DescriptorArgumentSpecification : public ArgumentSpecification { + String name = {}; + bool required = {}; + Descriptor defaultValue = {}; + + DescriptorArgumentSpecification() noexcept; +}; +void to_json(json& j, const DescriptorArgumentSpecification& p); +void from_json(const json& j, DescriptorArgumentSpecification& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct DistinctLimitNode : public PlanNode { std::shared_ptr source = {}; int64_t limit = {}; @@ -1404,6 +1486,15 @@ void to_json(json& j, const Function& p); void from_json(const json& j, Function& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct GenericTableReturnTypeSpecification : public ReturnTypeSpecification { + String returnType = {}; + + GenericTableReturnTypeSpecification() noexcept; +}; +void to_json(json& j, const GenericTableReturnTypeSpecification& p); +void from_json(const json& j, GenericTableReturnTypeSpecification& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct GroupIdNode : public PlanNode { std::shared_ptr source = {}; List> groupingSets = {}; @@ -1689,6 +1780,15 @@ void to_json(json& j, const JoinNodeStatsEstimate& p); void from_json(const json& j, JoinNodeStatsEstimate& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct JsonBasedTableFunctionMetadata { + QualifiedObjectName functionName = {}; + List> arguments = {}; + std::shared_ptr returnTypeSpecification = {}; +}; +void to_json(json& j, const JsonBasedTableFunctionMetadata& p); +void from_json(const json& j, JsonBasedTableFunctionMetadata& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct JsonBasedUdfFunctionMetadata { String docString = {}; FunctionKind functionKind = {}; @@ -1849,6 +1949,21 @@ void to_json(json& j, const MergeTarget& p); void from_json(const json& j, MergeTarget& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct NativeField { + std::shared_ptr name = {}; + std::shared_ptr typeSignature = {}; +}; +void to_json(json& j, const NativeField& p); +void from_json(const json& j, NativeField& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct NativeDescriptor { + List fields = {}; +}; +void to_json(json& j, const NativeDescriptor& p); +void from_json(const json& j, NativeDescriptor& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct NativeSidecarFailureInfo { String type = {}; String message = {}; @@ -1861,6 +1976,34 @@ void to_json(json& j, const NativeSidecarFailureInfo& p); void from_json(const json& j, NativeSidecarFailureInfo& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct NativeTableFunctionHandle : public ConnectorTableFunctionHandle { + String serializedTableFunctionHandle = {}; + QualifiedObjectName functionName = {}; + + NativeTableFunctionHandle() noexcept; +}; +void to_json(json& j, const NativeTableFunctionHandle& p); +void from_json(const json& j, NativeTableFunctionHandle& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct NativeTableFunctionAnalysis { + std::shared_ptr returnedType = {}; + Map> requiredColumns = {}; + NativeTableFunctionHandle handle = {}; +}; +void to_json(json& j, const NativeTableFunctionAnalysis& p); +void from_json(const json& j, NativeTableFunctionAnalysis& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct NativeTableFunctionSplit : public ConnectorSplit { + String serializedTableFunctionSplitHandle = {}; + + NativeTableFunctionSplit() noexcept; +}; +void to_json(json& j, const NativeTableFunctionSplit& p); +void from_json(const json& j, NativeTableFunctionSplit& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct NodeLoadMetrics { double cpuUsedPercent = {}; double memoryUsedInBytes = {}; @@ -1912,6 +2055,30 @@ void to_json(json& j, const NodeStatus& p); void from_json(const json& j, NodeStatus& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct Serializable { + Type type = {}; + Block block = {}; +}; +void to_json(json& j, const Serializable& p); +void from_json(const json& j, Serializable& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct NullableValue { + Serializable serializable = {}; +}; +void to_json(json& j, const NullableValue& p); +void from_json(const json& j, NullableValue& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct OnlyPassThroughReturnTypeSpecification : public ReturnTypeSpecification { + String returnType = {}; + + OnlyPassThroughReturnTypeSpecification() noexcept; +}; +void to_json(json& j, const OnlyPassThroughReturnTypeSpecification& p); +void from_json(const json& j, OnlyPassThroughReturnTypeSpecification& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { enum class BufferState { OPEN, NO_MORE_BUFFERS, @@ -1960,6 +2127,22 @@ void to_json(json& j, const PartialAggregationStatsEstimate& p); void from_json(const json& j, PartialAggregationStatsEstimate& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct PassThroughColumn { + VariableReferenceExpression outputVariables = {}; + bool partitioningColumn = {}; +}; +void to_json(json& j, const PassThroughColumn& p); +void from_json(const json& j, PassThroughColumn& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct PassThroughSpecification { + bool declaredAsPassThrough = {}; + List columns = {}; +}; +void to_json(json& j, const PassThroughSpecification& p); +void from_json(const json& j, PassThroughSpecification& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct PipelineStats { int pipelineId = {}; int64_t firstStartTimeInMillis = {}; @@ -2197,6 +2380,13 @@ void to_json(json& j, const RowNumberNode& p); void from_json(const json& j, RowNumberNode& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct RowType { + std::shared_ptr typeSignature = {}; +}; +void to_json(json& j, const RowType& p); +void from_json(const json& j, RowType& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { enum class RuntimeUnit { NONE, NANO, BYTE }; extern void to_json(json& j, const RuntimeUnit& e); extern void from_json(const json& j, RuntimeUnit& e); @@ -2230,6 +2420,26 @@ void to_json(json& j, const SampleNode& p); void from_json(const json& j, SampleNode& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct ScalarArgument : public Argument { + NullableValue nullableValue = {}; + + ScalarArgument() noexcept; +}; +void to_json(json& j, const ScalarArgument& p); +void from_json(const json& j, ScalarArgument& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct ScalarArgumentSpecification : public ArgumentSpecification { + String name = {}; + Type type = {}; + bool required = {}; + + ScalarArgumentSpecification() noexcept; +}; +void to_json(json& j, const ScalarArgumentSpecification& p); +void from_json(const json& j, ScalarArgumentSpecification& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { enum class DistributionType { PARTITIONED, REPLICATED }; extern void to_json(json& j, const DistributionType& e); extern void from_json(const json& j, DistributionType& e); @@ -2456,6 +2666,87 @@ void to_json(json& j, const SystemTransactionHandle& p); void from_json(const json& j, SystemTransactionHandle& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct TableArgument : public Argument { + RowType rowType = {}; + List partitionBy = {}; + List orderBy = {}; + List fields = {}; + + TableArgument() noexcept; +}; +void to_json(json& j, const TableArgument& p); +void from_json(const json& j, TableArgument& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct TableArgumentProperties { + String argumentName = {}; + bool rowSemantics = {}; + bool pruneWhenEmpty = {}; + PassThroughSpecification passThroughSpecification = {}; + List requiredColumns = {}; + std::shared_ptr specification = {}; +}; +void to_json(json& j, const TableArgumentProperties& p); +void from_json(const json& j, TableArgumentProperties& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct TableArgumentSpecification : public ArgumentSpecification { + String name = {}; + bool rowSemantics = {}; + bool pruneWhenEmpty = {}; + bool passThroughColumns = {}; + + TableArgumentSpecification() noexcept; +}; +void to_json(json& j, const TableArgumentSpecification& p); +void from_json(const json& j, TableArgumentSpecification& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct TableFunctionHandle { + ConnectorId connectorId = {}; + std::shared_ptr functionHandle = {}; + std::shared_ptr transactionHandle = {}; +}; +void to_json(json& j, const TableFunctionHandle& p); +void from_json(const json& j, TableFunctionHandle& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct TableFunctionNode : public PlanNode { + String name = {}; + Map> arguments = {}; + List outputVariables = {}; + List> sources = {}; + List tableArgumentProperties = {}; + List> copartitioningLists = {}; + TableFunctionHandle handle = {}; + + TableFunctionNode() noexcept; +}; +void to_json(json& j, const TableFunctionNode& p); +void from_json(const json& j, TableFunctionNode& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { +struct TableFunctionProcessorNode : public PlanNode { + String name = {}; + List properOutputs = {}; + std::shared_ptr> source = {}; + bool pruneWhenEmpty = {}; + List passThroughSpecifications = {}; + List> requiredVariables = {}; + std::shared_ptr> + markerVariables = {}; + std::shared_ptr specification = {}; + List prePartitioned = {}; + int preSorted = {}; + std::shared_ptr hashSymbol = {}; + TableFunctionHandle handle = {}; + + TableFunctionProcessorNode() noexcept; +}; +void to_json(json& j, const TableFunctionProcessorNode& p); +void from_json(const json& j, TableFunctionProcessorNode& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct TableScanNode : public PlanNode { TableHandle table = {}; List outputVariables = {}; diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml index c02a4825f550c..5fc1ec7ae5c7b 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml @@ -123,6 +123,7 @@ AbstractClasses: - { name: EmptySplit, key: $empty } - { name: SystemSplit, key: $system@system } - { name: ArrowSplit, key: arrow-flight } + - { name: NativeTableFunctionSplit, key: system:com.facebook.presto.tvf.NativeTableFunctionSplit } ConnectorHistogram: super: JsonEncodedSubclass @@ -179,6 +180,8 @@ AbstractClasses: - { name: MergeJoinNode, key: .MergeJoinNode } - { name: WindowNode, key: .WindowNode } - { name: CallDistributedProcedureNode, key: com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode } + - { name: TableFunctionNode, key: com.facebook.presto.sql.planner.plan.TableFunctionNode } + - { name: TableFunctionProcessorNode, key: com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode } RowExpression: super: JsonEncodedSubclass @@ -205,6 +208,31 @@ AbstractClasses: - { name: SqlFunctionHandle, key: sql_function_handle } - { name: RestFunctionHandle, key: rest } + Argument: + super: JsonEncodedSubclass + subclasses: + - { name: DescriptorArgument, key: descriptor } + - { name: ScalarArgument, key: scalar } + - { name: TableArgument, key: table } + + ArgumentSpecification: + super: JsonEncodedSubclass + subclasses: + - { name: DescriptorArgumentSpecification, key: descriptor } + - { name: ScalarArgumentSpecification, key: scalar } + - { name: TableArgumentSpecification, key: table } + + ReturnTypeSpecification: + super: JsonEncodedSubclass + subclasses: + - { name: DescribedTableReturnTypeSpecification, key: described_table } + - { name: GenericTableReturnTypeSpecification, key: generic_table } + - { name: OnlyPassThroughReturnTypeSpecification, key: only_pass_through_table } + + ConnectorTableFunctionHandle: + super: JsonEncodedSubclass + subclasses: + - { name: NativeTableFunctionHandle, key: system:com.facebook.presto.tvf.NativeTableFunctionHandle } JavaClasses: - presto-spi/src/main/java/com/facebook/presto/spi/ErrorCause.java @@ -250,6 +278,9 @@ JavaClasses: - presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/RowNumberNode.java - presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SampleNode.java - presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java + - presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionHandle.java + - presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java + - presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java - presto-spi/src/main/java/com/facebook/presto/spi/plan/TableWriterNode.java - presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableWriterMergeNode.java - presto-spi/src/main/java/com/facebook/presto/spi/plan/UnnestNode.java @@ -363,3 +394,26 @@ JavaClasses: - presto-spi/src/main/java/com/facebook/presto/spi/NodeStats.java - presto-spi/src/main/java/com/facebook/presto/spi/NodeLoadMetrics.java - presto-spi/src/main/java/com/facebook/presto/spi/session/SessionPropertyMetadata.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/AbstractConnectorTableFunction.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/ConnectorTableFunctionHandle.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/Argument.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgument.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/ScalarArgument.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgument.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/ArgumentSpecification.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/ScalarArgumentSpecification.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgumentSpecification.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgumentSpecification.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/Descriptor.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/ReturnTypeSpecification.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescribedTableReturnTypeSpecification.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/GenericTableReturnTypeSpecification.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/table/OnlyPassThroughReturnTypeSpecification.java + - presto-native-tvf/src/main/java/com/facebook/presto/tvf/ConnectorTableMetadata.java + - presto-common/src/main/java/com/facebook/presto/common/type/RowType.java + - presto-common/src/main/java/com/facebook/presto/common/predicate/NullableValue.java + - presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTableFunctionHandle.java + - presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTableFunctionSplit.java + - presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTableFunctionAnalysis.java + - presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeDescriptor.java + - presto-native-tvf/src/main/java/com/facebook/presto/tvf/JsonBasedTableFunctionMetadata.java diff --git a/presto-native-execution/presto_cpp/presto_protocol/java-to-struct-json.py b/presto-native-execution/presto_cpp/presto_protocol/java-to-struct-json.py index 68238a6b7ffa2..50daf5dfb84b8 100755 --- a/presto-native-execution/presto_cpp/presto_protocol/java-to-struct-json.py +++ b/presto-native-execution/presto_cpp/presto_protocol/java-to-struct-json.py @@ -40,7 +40,7 @@ language = { "cpp": { "TypeMap": { - r"([ ,<])(ColumnHandle|PlanNode|RowExpression|ConnectorDeleteTableHandle)([ ,>])": r"\1std::shared_ptr<\2>\3", + r"([ ,<])(ColumnHandle|PlanNode|RowExpression|ConnectorDeleteTableHandle|ArgumentSpecification|Argument)([ ,>])": r"\1std::shared_ptr<\2>\3", r"Optional": "Optional>", r"Optional": "Optional>", r"int\[\]": "List", @@ -52,6 +52,7 @@ r"Set<(.*)>": r"List<\1>", r"Optional<(.*)>": {"replace": r"\1", "flag": {"optional": True}}, r"ExchangeNode.Type": "ExchangeNodeType", + r"RowType.Field": "Field", } }, "pb": { diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeTvfFunctions.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeTvfFunctions.java new file mode 100644 index 0000000000000..809d6aa88c28b --- /dev/null +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeTvfFunctions.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.nativeworker; + +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.facebook.presto.tvf.TvfPlugin; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createRegion; + +@Test +public abstract class AbstractTestNativeTvfFunctions + extends AbstractTestQueryFramework +{ + @BeforeClass + @Override + public void init() throws Exception + { + super.init(); + getQueryRunner().installCoordinatorPlugin(new TvfPlugin()); + getQueryRunner().loadTVFProvider("system"); + } + + @Override + protected void createTables() + { + createRegion((QueryRunner) getExpectedQueryRunner()); + } + + @Test + public void testSequence() + { + assertQuery("SELECT * FROM TABLE(sequence( start => 20, stop => 100, step => 5))"); + } + + @Test + public void testExcludeColumns() + { + assertQuery("SELECT * FROM TABLE(exclude_columns(input => TABLE(region), columns => DESCRIPTOR(regionkey, comment)))"); + } +} diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java index a7a715f533883..85580dc6dd13c 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/PrestoNativeQueryRunnerUtils.java @@ -30,6 +30,7 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.DistributedQueryRunner; +import com.facebook.presto.tvf.TvfPlugin; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; @@ -292,17 +293,21 @@ public QueryRunner build() externalWorkerLauncher = getExternalWorkerLauncher("hive", serverBinary, cacheMaxSize, remoteFunctionServerUds, pluginDirectory, failOnNestedLoopJoin, coordinatorSidecarEnabled, builtInWorkerFunctionsEnabled, enableRuntimeMetricsCollection, enableSsdCache, implicitCastCharNToVarchar); } - return HiveQueryRunner.createQueryRunner( - ImmutableList.of(), - ImmutableList.of(), - extraProperties, - extraCoordinatorProperties, - security, - hiveProperties, - Optional.ofNullable(workerCount), - Optional.of(Paths.get(addStorageFormatToPath ? dataDirectory.toString() + "/" + storageFormat : dataDirectory.toString())), - externalWorkerLauncher, - tpcdsProperties); + + QueryRunner queryRunner = HiveQueryRunner.createQueryRunner( + ImmutableList.of(), + ImmutableList.of(), + extraProperties, + extraCoordinatorProperties, + security, + hiveProperties, + Optional.ofNullable(workerCount), + Optional.of(Paths.get(addStorageFormatToPath ? dataDirectory.toString() + "/" + storageFormat : dataDirectory.toString())), + externalWorkerLauncher, + tpcdsProperties); + //queryRunner.installCoordinatorPlugin(new TvfPlugin()); + //queryRunner.loadTVFProvider("system"); + return queryRunner; } } diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeTvfFunctions.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeTvfFunctions.java new file mode 100644 index 0000000000000..745e592f8ac5d --- /dev/null +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeTvfFunctions.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.nativeworker; + +import com.facebook.presto.testing.ExpectedQueryRunner; +import com.facebook.presto.testing.QueryRunner; + +public class TestPrestoNativeTvfFunctions + extends AbstractTestNativeTvfFunctions +{ + @Override + protected QueryRunner createQueryRunner() throws Exception + { + return PrestoNativeQueryRunnerUtils.nativeHiveQueryRunnerBuilder() + .setAddStorageFormatToPath(true) + .build(); + } + + @Override + protected ExpectedQueryRunner createExpectedQueryRunner() throws Exception + { + return PrestoNativeQueryRunnerUtils.javaHiveQueryRunnerBuilder() + .setAddStorageFormatToPath(true) + .build(); + } +} diff --git a/presto-native-tests/src/test/java/com/facebook/presto/nativetests/operator/scalar/AbstractTestNativeFunctions.java b/presto-native-tests/src/test/java/com/facebook/presto/nativetests/operator/scalar/AbstractTestNativeFunctions.java index 434a1414ea1f0..6c99067f3f399 100644 --- a/presto-native-tests/src/test/java/com/facebook/presto/nativetests/operator/scalar/AbstractTestNativeFunctions.java +++ b/presto-native-tests/src/test/java/com/facebook/presto/nativetests/operator/scalar/AbstractTestNativeFunctions.java @@ -76,7 +76,7 @@ public void assertNotSupported(String projection, @Language("RegExp") String mes fail("expected exception"); } catch (RuntimeException ex) { - assertExceptionMessage(rewritten, ex, message, true); + assertExceptionMessage(rewritten, ex, message, true, false); } } diff --git a/presto-native-tvf/pom.xml b/presto-native-tvf/pom.xml new file mode 100644 index 0000000000000..e32ca900536ef --- /dev/null +++ b/presto-native-tvf/pom.xml @@ -0,0 +1,103 @@ + + +4.0.0 + + + com.facebook.presto + presto-root + 0.297-SNAPSHOT + + + presto-native-tvf + presto-native-tvf + Presto - Native table valued functions + presto-plugin + + + ${project.parent.basedir} + 17 + true + + + + com.google.guava + guava + + + + com.google.inject + guice + + + + javax.inject + javax.inject + + + + com.facebook.airlift.drift + drift-api + provided + + + + io.airlift + slice + provided + + + + org.openjdk.jol + jol-core + provided + + + + com.facebook.airlift + units + provided + + + + com.fasterxml.jackson.core + jackson-annotations + provided + + + + com.fasterxml.jackson.core + jackson-databind + + + + com.facebook.airlift + json + + + + + com.facebook.presto + presto-spi + provided + + + + com.facebook.presto + presto-common + provided + + + com.facebook.airlift + bootstrap + + + com.facebook.airlift + http-client + + + com.facebook.presto + presto-main-base + + + diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/ConnectorTableMetadata.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/ConnectorTableMetadata.java new file mode 100644 index 0000000000000..97b41a2e3aa05 --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/ConnectorTableMetadata.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.spi.function.table.Argument; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class ConnectorTableMetadata +{ + private final QualifiedObjectName functionName; + private final Map arguments; + + @JsonCreator + public ConnectorTableMetadata( + @JsonProperty("functionName") QualifiedObjectName functionName, + @JsonProperty("arguments") Map arguments) + { + this.functionName = requireNonNull(functionName, "functionName is null"); + this.arguments = ImmutableMap.copyOf(requireNonNull(arguments, "arguments is null")); + } + + @JsonProperty("functionName") + public QualifiedObjectName getFunctionName() + { + return functionName; + } + + @JsonProperty + public Map getArguments() + { + return arguments; + } +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/ForWorkerInfo.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/ForWorkerInfo.java new file mode 100644 index 0000000000000..db71aeee08d73 --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/ForWorkerInfo.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.google.inject.BindingAnnotation; + +import java.lang.annotation.Retention; + +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@BindingAnnotation +public @interface ForWorkerInfo +{ +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/HttpClientHolder.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/HttpClientHolder.java new file mode 100644 index 0000000000000..14d65d99afc3c --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/HttpClientHolder.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.facebook.airlift.http.client.HttpClient; + +public class HttpClientHolder +{ + private static HttpClient httpClient; + + private HttpClientHolder() {} + + public static void setHttpClient(@ForWorkerInfo HttpClient httpClient) + { + HttpClientHolder.httpClient = httpClient; + } + + public static HttpClient getHttpClient() + { + return httpClient; + } +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/JsonBasedTableFunctionMetadata.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/JsonBasedTableFunctionMetadata.java new file mode 100644 index 0000000000000..9a86b3ad24a7a --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/JsonBasedTableFunctionMetadata.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.spi.function.table.ArgumentSpecification; +import com.facebook.presto.spi.function.table.ReturnTypeSpecification; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class JsonBasedTableFunctionMetadata +{ + private final List arguments; + private final ReturnTypeSpecification returnTypeSpecification; + private final QualifiedObjectName functionName; + + @JsonCreator + public JsonBasedTableFunctionMetadata( + @JsonProperty("functionName") QualifiedObjectName functionName, + @JsonProperty("arguments") List arguments, + @JsonProperty("returnTypeSpecification") ReturnTypeSpecification returnTypeSpecification) + { + this.functionName = requireNonNull(functionName, "functionName is null"); + this.arguments = Collections.unmodifiableList(new ArrayList<>(requireNonNull(arguments, "arguments is null"))); + this.returnTypeSpecification = requireNonNull(returnTypeSpecification, "returnTypeSpecification is null"); + } + + @JsonProperty + public QualifiedObjectName getQualifiedObjectName() + { + return functionName; + } + + @JsonProperty + public List getArguments() + { + return arguments; + } + + @JsonProperty + public ReturnTypeSpecification getReturnTypeSpecification() + { + return returnTypeSpecification; + } +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeConnectorTableFunction.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeConnectorTableFunction.java new file mode 100644 index 0000000000000..b65d06cfdf042 --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeConnectorTableFunction.java @@ -0,0 +1,113 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.airlift.http.client.Request; +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonCodecFactory; +import com.facebook.airlift.json.JsonObjectMapperProvider; +import com.facebook.presto.block.BlockJsonSerde.Serializer; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockEncodingManager; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.AbstractConnectorTableFunction; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ArgumentSpecification; +import com.facebook.presto.spi.function.table.ReturnTypeSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.google.common.collect.ImmutableMap; + +import java.util.List; +import java.util.Map; + +import static com.facebook.airlift.http.client.JsonBodyGenerator.jsonBodyGenerator; +import static com.facebook.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; +import static com.facebook.airlift.http.client.Request.Builder.preparePost; +import static com.facebook.presto.spi.StandardErrorCode.TABLE_FUNCTION_ANALYSIS_FAILED; +import static com.facebook.presto.tvf.NativeTVFProvider.getWorkerLocation; +import static com.google.common.net.HttpHeaders.ACCEPT; +import static com.google.common.net.HttpHeaders.CONTENT_TYPE; +import static com.google.common.net.MediaType.JSON_UTF_8; +import static java.util.Objects.requireNonNull; + +public class NativeConnectorTableFunction + extends AbstractConnectorTableFunction +{ + private final HttpClient httpClient; + private final NodeManager nodeManager; + private final TypeManager typeManager; + private static final String TVF_ANALYZE_ENDPOINT = "/v1/tvf/analyze"; + private static final JsonCodec connectorTableMetadataJsonCodec; + private static final JsonCodec tableFunctionAnalysisJsonCodec = + JsonCodec.jsonCodec(NativeTableFunctionAnalysis.class); + private final QualifiedObjectName functionName; + + static { + JsonObjectMapperProvider provider = new JsonObjectMapperProvider(); + provider.setJsonSerializers(ImmutableMap.of( + Block.class, new Serializer(new BlockEncodingManager()))); + + ObjectMapper mapper = provider.get(); + mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); + JsonCodecFactory codecFactory = new JsonCodecFactory(provider); + connectorTableMetadataJsonCodec = codecFactory.jsonCodec(ConnectorTableMetadata.class); + } + + public NativeConnectorTableFunction( + @ForWorkerInfo HttpClient httpClient, + NodeManager nodeManager, + TypeManager typeManager, + QualifiedObjectName functionName, + List arguments, + ReturnTypeSpecification returnTypeSpecification) + { + super("builtin", functionName.getObjectName(), arguments, returnTypeSpecification); + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.functionName = requireNonNull(functionName, "functionName is null"); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + try { + return httpClient.execute( + getWorkerRequest(arguments), + createJsonResponseHandler(tableFunctionAnalysisJsonCodec)).toTableFunctionAnalysis(typeManager); + } + catch (Exception e) { + throw new PrestoException(TABLE_FUNCTION_ANALYSIS_FAILED, "Failed to analyze function.", e); + } + } + + private Request getWorkerRequest(Map arguments) + { + return preparePost() + .setUri(getWorkerLocation(nodeManager, TVF_ANALYZE_ENDPOINT)) + .setBodyGenerator( + jsonBodyGenerator(connectorTableMetadataJsonCodec, new ConnectorTableMetadata(functionName, arguments))) + .setHeader(CONTENT_TYPE, JSON_UTF_8.toString()) + .setHeader(ACCEPT, JSON_UTF_8.toString()) + .build(); + } +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeDescriptor.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeDescriptor.java new file mode 100644 index 0000000000000..0fd08fd2e2bf2 --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeDescriptor.java @@ -0,0 +1,111 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.facebook.presto.common.type.TypeSignature; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static java.util.Collections.unmodifiableList; +import static java.util.Objects.requireNonNull; + +public class NativeDescriptor +{ + private final List fields; + + @JsonCreator + public NativeDescriptor(@JsonProperty("fields") List fields) + { + requireNonNull(fields, "fields is null"); + checkArgument(!fields.isEmpty(), "descriptor has no fields"); + this.fields = unmodifiableList(fields); + } + + @JsonProperty + public List getFields() + { + return fields; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + NativeDescriptor that = (NativeDescriptor) o; + return fields.equals(that.fields); + } + + @Override + public int hashCode() + { + return Objects.hash(fields); + } + + public static class NativeField + { + private final Optional name; + private final Optional typeSignature; + + @JsonCreator + public NativeField( + @JsonProperty("name") Optional name, + @JsonProperty("typeSignature") Optional typeSignature) + { + this.name = requireNonNull(name, "name is null"); + name.ifPresent(nameValue -> checkArgument(!nameValue.isEmpty(), "name is empty")); + this.typeSignature = requireNonNull(typeSignature, "typeSignature is null"); + } + + @JsonProperty + public Optional getName() + { + return name; + } + + @JsonProperty + public Optional getTypeSignature() + { + return typeSignature; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + NativeField field = (NativeField) o; + return name.equals(field.name) && typeSignature.equals(field.typeSignature); + } + + @Override + public int hashCode() + { + return Objects.hash(name, typeSignature); + } + } +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTVFProvider.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTVFProvider.java new file mode 100644 index 0000000000000..882b04a18d6f3 --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTVFProvider.java @@ -0,0 +1,151 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.airlift.http.client.HttpUriBuilder; +import com.facebook.airlift.http.client.Request; +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonCodecFactory; +import com.facebook.airlift.json.JsonObjectMapperProvider; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.spi.Node; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.tvf.TVFProvider; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.deser.std.FromStringDeserializer; +import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; + +import javax.inject.Inject; + +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import static com.facebook.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; +import static com.facebook.airlift.http.client.Request.Builder.prepareGet; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static java.util.Objects.requireNonNull; + +public class NativeTVFProvider + implements TVFProvider +{ + private final NodeManager nodeManager; + private final TypeManager typeManager; + private final HttpClient httpClient; + private static final String TABLE_FUNCTIONS_ENDPOINT = "/v1/functions/tvf"; + private final JsonCodec> connectorTableFunctionListJsonCodec; + private final Supplier> memoizedTableFunctionsSupplier; + + @Inject + public NativeTVFProvider( + NodeManager nodeManager, + @ForWorkerInfo HttpClient httpClient, + TypeManager typeManager) + { + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.memoizedTableFunctionsSupplier = Suppliers.memoizeWithExpiration(this::loadConnectorTableFunctions, + 100000, TimeUnit.MILLISECONDS); + + JsonObjectMapperProvider provider = new JsonObjectMapperProvider(); + + provider.setJsonDeserializers(ImmutableMap.of( + Type.class, new TypeDeserializer(typeManager))); + + ObjectMapper mapper = provider.get(); + mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); + JsonCodecFactory codecFactory = new JsonCodecFactory(provider); + this.connectorTableFunctionListJsonCodec = codecFactory.mapJsonCodec(String.class, JsonBasedTableFunctionMetadata.class); + } + + @Override + public List getTableFunctions() + { + return memoizedTableFunctionsSupplier.get(); + } + + public static URI getWorkerLocation(NodeManager nodeManager, String endpoint) + { + Set workerNodes = nodeManager.getWorkerNodes(); + if (workerNodes.isEmpty()) { + throw new IllegalStateException("No worker nodes available"); + } + Node workerNode = Iterables.get(workerNodes, new Random().nextInt(workerNodes.size())); + return HttpUriBuilder.uriBuilder() + .scheme("http") + .host(workerNode.getHost()) + .port(workerNode.getHostAndPort().getPort()) + .appendPath(endpoint) + .build(); + } + + private synchronized List loadConnectorTableFunctions() + { + Map connectorTableFunctions; + try { + Request request = prepareGet().setUri(getWorkerLocation(nodeManager, TABLE_FUNCTIONS_ENDPOINT)).build(); + connectorTableFunctions = httpClient.execute(request, createJsonResponseHandler(connectorTableFunctionListJsonCodec)); + } + catch (Exception e) { + throw new PrestoException(INVALID_ARGUMENTS, "Failed to get table functions from endpoint.", e); + } + + return connectorTableFunctions.values().stream().map(this::createNativeConnectorTableFunction).collect(ImmutableList.toImmutableList()); + } + + private synchronized NativeConnectorTableFunction createNativeConnectorTableFunction(JsonBasedTableFunctionMetadata connectorTableFunction) + { + return new NativeConnectorTableFunction( + httpClient, + nodeManager, + typeManager, + connectorTableFunction.getQualifiedObjectName(), + connectorTableFunction.getArguments(), + connectorTableFunction.getReturnTypeSpecification()); + } + + public static final class TypeDeserializer + extends FromStringDeserializer + { + private final TypeManager typeManager; + + @Inject + public TypeDeserializer(TypeManager typeManager) + { + super(Type.class); + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + } + + @Override + protected Type _deserialize(String value, DeserializationContext context) + { + return typeManager.getType(parseTypeSignature(value)); + } + } +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTVFProviderFactory.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTVFProviderFactory.java new file mode 100644 index 0000000000000..df5302e7da503 --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTVFProviderFactory.java @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; +import com.facebook.presto.spi.tvf.TVFProvider; +import com.facebook.presto.spi.tvf.TVFProviderContext; +import com.facebook.presto.spi.tvf.TVFProviderFactory; +import com.google.inject.Injector; +import com.google.inject.Key; + +import java.util.Map; + +import static com.google.common.base.Throwables.throwIfUnchecked; + +/** + * Factory class to create instance of {@link NativeTVFProvider}. + * This factor is registered in {@link TvfPlugin#getTVFProviderFactories()} ()}. + */ +public class NativeTVFProviderFactory + implements TVFProviderFactory +{ + private static final String NAME = "system"; + + private static final NativeTableFunctionHandle.Resolver HANDLE_RESOLVER = new NativeTableFunctionHandle.Resolver(); + + private static final NativeTableFunctionSplit.Resolver SPLIT_RESOLVER = new NativeTableFunctionSplit.Resolver(); + + @Override + public String getName() + { + return NAME; + } + + @Override + public TableFunctionHandleResolver getTableFunctionHandleResolver() + { + return HANDLE_RESOLVER; + } + + @Override + public TableFunctionSplitResolver getTableFunctionSplitResolver() + { + return SPLIT_RESOLVER; + } + + @Override + public TVFProvider createTVFProvider(Map config, TVFProviderContext context) + { + try { + Bootstrap app = new Bootstrap( + new NativeTVFProviderModule(context.getNodeManager(), context.getTypeManager()), + new NativeWorkerCommunicationModule()); + + Injector injector = app + .doNotInitializeLogging() + .setRequiredConfigurationProperties(config) + .initialize(); + + Key httpClientKey = Key.get(HttpClient.class, ForWorkerInfo.class); + HttpClientHolder.setHttpClient(injector.getInstance(httpClientKey)); + + return injector.getInstance(NativeTVFProvider.class); + } + catch (Exception e) { + throwIfUnchecked(e); + throw new RuntimeException(e); + } + } +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTVFProviderModule.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTVFProviderModule.java new file mode 100644 index 0000000000000..918efe9eb2428 --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTVFProviderModule.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.spi.NodeManager; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static java.util.Objects.requireNonNull; + +public class NativeTVFProviderModule + implements Module +{ + private final NodeManager nodeManager; + private final TypeManager typeManager; + + public NativeTVFProviderModule(NodeManager nodeManager, TypeManager typeManager) + { + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + } + + @Override + public void configure(Binder binder) + { + binder.bind(NodeManager.class).toInstance(nodeManager); + binder.bind(TypeManager.class).toInstance(typeManager); + binder.bind(NativeTVFProvider.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTableFunctionAnalysis.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTableFunctionAnalysis.java new file mode 100644 index 0000000000000..bcafc166bcf66 --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTableFunctionAnalysis.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class NativeTableFunctionAnalysis +{ + // a map from table argument name to list of column indexes for all columns required from the table argument + private final Map> requiredColumns; + + private final Optional returnedType; + private final NativeTableFunctionHandle handle; + + @JsonCreator + public NativeTableFunctionAnalysis( + @JsonProperty("returnedType") Optional returnedType, + @JsonProperty("requiredColumns") Map> requiredColumns, + @JsonProperty("handle") NativeTableFunctionHandle handle) + { + this.returnedType = requireNonNull(returnedType, "returnedType is null"); + this.requiredColumns = Collections.unmodifiableMap( + requiredColumns.entrySet().stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + entry -> Collections.unmodifiableList(entry.getValue())))); + this.handle = requireNonNull(handle, "handle is null"); + } + + @JsonProperty + public Optional getReturnedType() + { + return returnedType; + } + + @JsonProperty + public Map> getRequiredColumns() + { + return requiredColumns; + } + + @JsonProperty + public NativeTableFunctionHandle getHandle() + { + return handle; + } + + public TableFunctionAnalysis toTableFunctionAnalysis(TypeManager typeManager) + { + Descriptor descriptor = null; + if (returnedType.isPresent()) { + descriptor = new Descriptor( + convertToDescriptorFields(returnedType.get().getFields(), typeManager)); + } + TableFunctionAnalysis.Builder builder = TableFunctionAnalysis.builder(); + builder .returnedType(descriptor); + for (Map.Entry> entry : requiredColumns.entrySet()) { + builder .requiredColumns(entry.getKey(), entry.getValue()); + } + builder .handle(handle); + return builder.build(); + } + + private static List convertToDescriptorFields(List nativeFields, TypeManager typeManager) + { + return nativeFields.stream() + .map(field -> new Descriptor.Field( + field.getName(), + Optional.ofNullable(typeManager.getType(field.getTypeSignature().orElse(null))))) + .collect(toImmutableList()); + } +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTableFunctionHandle.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTableFunctionHandle.java new file mode 100644 index 0000000000000..16b2b8bdf7e30 --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTableFunctionHandle.java @@ -0,0 +1,135 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonCodecFactory; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.index.IndexHandleJacksonModule; +import com.facebook.presto.metadata.ColumnHandleJacksonModule; +import com.facebook.presto.metadata.DeleteTableHandleJacksonModule; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.FunctionHandleJacksonModule; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.metadata.InsertTableHandleJacksonModule; +import com.facebook.presto.metadata.OutputTableHandleJacksonModule; +import com.facebook.presto.metadata.PartitioningHandleJacksonModule; +import com.facebook.presto.metadata.SplitJacksonModule; +import com.facebook.presto.metadata.TableFunctionJacksonHandleModule; +import com.facebook.presto.metadata.TableHandleJacksonModule; +import com.facebook.presto.metadata.TableLayoutHandleJacksonModule; +import com.facebook.presto.metadata.TransactionHandleJacksonModule; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableSet; + +import java.util.Optional; +import java.util.Set; + +import static com.facebook.airlift.http.client.JsonBodyGenerator.jsonBodyGenerator; +import static com.facebook.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; +import static com.facebook.airlift.http.client.Request.Builder.preparePost; +import static com.facebook.presto.tvf.NativeTVFProvider.getWorkerLocation; +import static com.google.common.net.HttpHeaders.ACCEPT; +import static com.google.common.net.HttpHeaders.CONTENT_TYPE; +import static com.google.common.net.MediaType.JSON_UTF_8; +import static java.util.Objects.requireNonNull; + +public class NativeTableFunctionHandle + implements ConnectorTableFunctionHandle +{ + private static final String TVF_SPLITS_ENDPOINT = "/v1/tvf/splits"; + + private final QualifiedObjectName functionName; + private final String serializedTableFunctionHandle; + + @JsonCreator + public NativeTableFunctionHandle( + @JsonProperty("serializedTableFunctionHandle") String serializedTableFunctionHandle, + @JsonProperty("functionName") QualifiedObjectName functionName) + { + this.serializedTableFunctionHandle = requireNonNull(serializedTableFunctionHandle, "serializedTableFunctionHandle is null"); + this.functionName = requireNonNull(functionName, "functionName is null"); + } + + @JsonProperty + public String getSerializedTableFunctionHandle() + { + return serializedTableFunctionHandle; + } + + @JsonProperty("functionName") + public QualifiedObjectName getFunctionName() + { + return functionName; + } + + public static class Resolver + implements TableFunctionHandleResolver + { + @Override + public Set> getTableFunctionHandleClasses() + { + return ImmutableSet.of(NativeTableFunctionHandle.class); + } + } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, ConnectorSession session, NodeManager nodeManager, Object functionAndTypeManager) + { + if (functionAndTypeManager instanceof FunctionAndTypeManager) { + ObjectMapper objectMapper = new ObjectMapper(); + HandleResolver handleResolver = ((FunctionAndTypeManager) functionAndTypeManager).getHandleResolver(); + + FeaturesConfig featuresConfig = new FeaturesConfig(); + featuresConfig.setUseConnectorProvidedSerializationCodecs(false); + + objectMapper.registerModule(new TableHandleJacksonModule(handleResolver, featuresConfig, connectorId -> Optional.empty())); + objectMapper.registerModule(new TableLayoutHandleJacksonModule(handleResolver, featuresConfig, connectorId -> Optional.empty())); + objectMapper.registerModule(new ColumnHandleJacksonModule(handleResolver, featuresConfig, connectorId -> Optional.empty())); + objectMapper.registerModule(new SplitJacksonModule(handleResolver, featuresConfig, connectorId -> Optional.empty())); + objectMapper.registerModule(new OutputTableHandleJacksonModule(handleResolver, featuresConfig, connectorId -> Optional.empty())); + objectMapper.registerModule(new InsertTableHandleJacksonModule(handleResolver, featuresConfig, connectorId -> Optional.empty())); + objectMapper.registerModule(new DeleteTableHandleJacksonModule(handleResolver, featuresConfig, connectorId -> Optional.empty())); + objectMapper.registerModule(new IndexHandleJacksonModule(handleResolver, featuresConfig, connectorId -> Optional.empty())); + objectMapper.registerModule(new TransactionHandleJacksonModule(handleResolver, featuresConfig, connectorId -> Optional.empty())); + objectMapper.registerModule(new PartitioningHandleJacksonModule(handleResolver, featuresConfig, connectorId -> Optional.empty())); + objectMapper.registerModule(new FunctionHandleJacksonModule(handleResolver)); + objectMapper.registerModule(new TableFunctionJacksonHandleModule(handleResolver, featuresConfig, connectorId -> Optional.empty())); + JsonCodecFactory jsonCodecFactory = new JsonCodecFactory(() -> objectMapper); + JsonCodec nativeTableFunctionHandleCodec = jsonCodecFactory.jsonCodec(ConnectorTableFunctionHandle.class); + + return new FixedSplitSource( + HttpClientHolder.getHttpClient().execute( + preparePost() + .setUri(getWorkerLocation(nodeManager, TVF_SPLITS_ENDPOINT)) + .setBodyGenerator(jsonBodyGenerator(nativeTableFunctionHandleCodec, this)) + .setHeader(CONTENT_TYPE, JSON_UTF_8.toString()) + .setHeader(ACCEPT, JSON_UTF_8.toString()) + .build(), + createJsonResponseHandler(JsonCodec.listJsonCodec(NativeTableFunctionSplit.class)))); + } + + throw new UnsupportedOperationException(); + } +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTableFunctionSplit.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTableFunctionSplit.java new file mode 100644 index 0000000000000..922003ef1791f --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeTableFunctionSplit.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableSet; + +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import static java.util.Objects.requireNonNull; + +public class NativeTableFunctionSplit + implements ConnectorSplit +{ + private final String serializedTableFunctionSplitHandle; + + @JsonCreator + public NativeTableFunctionSplit( + @JsonProperty("serializedTableFunctionSplitHandle") String serializedTableFunctionSplitHandle) + { + this.serializedTableFunctionSplitHandle = requireNonNull(serializedTableFunctionSplitHandle, "serializedTableFunctionSplitHandle is null"); + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NodeSelectionStrategy.NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return null; + } + + @JsonProperty + public String getSerializedTableFunctionSplitHandle() + { + return serializedTableFunctionSplitHandle; + } + + public static class Resolver + implements TableFunctionSplitResolver + { + @Override + public Set> getTableFunctionSplitClasses() + { + return ImmutableSet.of(NativeTableFunctionSplit.class); + } + } +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeWorkerCommunicationModule.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeWorkerCommunicationModule.java new file mode 100644 index 0000000000000..0ae06fde478b3 --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/NativeWorkerCommunicationModule.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.google.inject.Binder; +import com.google.inject.Module; + +import static com.facebook.airlift.http.client.HttpClientBinder.httpClientBinder; + +public class NativeWorkerCommunicationModule + implements Module +{ + @Override + public void configure(Binder binder) + { + httpClientBinder(binder).bindHttpClient("worker", ForWorkerInfo.class); + } +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/ServingCatalog.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/ServingCatalog.java new file mode 100644 index 0000000000000..de2f70f3e1776 --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/ServingCatalog.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import javax.inject.Qualifier; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target({FIELD, PARAMETER, METHOD}) +@Qualifier +public @interface ServingCatalog +{ +} diff --git a/presto-native-tvf/src/main/java/com/facebook/presto/tvf/TvfPlugin.java b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/TvfPlugin.java new file mode 100644 index 0000000000000..c24654bcd4102 --- /dev/null +++ b/presto-native-tvf/src/main/java/com/facebook/presto/tvf/TvfPlugin.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tvf; + +import com.facebook.presto.spi.CoordinatorPlugin; +import com.facebook.presto.spi.tvf.TVFProviderFactory; +import com.google.common.collect.ImmutableList; + +public class TvfPlugin + implements CoordinatorPlugin +{ + @Override + public Iterable getTVFProviderFactories() + { + return ImmutableList.of(new NativeTVFProviderFactory()); + } +} diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java index 8d853545bb39e..2cdca571cc78a 100644 --- a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java +++ b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java @@ -1667,7 +1667,7 @@ public Node visitDescriptorArgument(SqlBaseParser.DescriptorArgumentContext cont @Override public Node visitDescriptorField(SqlBaseParser.DescriptorFieldContext context) { - return new DescriptorField(getLocation(context), (Identifier) visit(context.identifier()), Optional.of(getType(context.type()))); + return new DescriptorField(getLocation(context), (Identifier) visit(context.identifier()), Optional.ofNullable(context.type()).map(this::getType)); } /** diff --git a/presto-server/src/main/provisio/presto.xml b/presto-server/src/main/provisio/presto.xml index b7176785fedd4..e7d2f47d35625 100644 --- a/presto-server/src/main/provisio/presto.xml +++ b/presto-server/src/main/provisio/presto.xml @@ -377,4 +377,10 @@ + + + + + + diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java index 593538ad6a242..9853b3a5c8289 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/PrestoSparkRddFactory.java @@ -223,7 +223,7 @@ private JavaPairRDD cre Optional taskSourceRdd; List sources = findTableScanNodes(fragment.getRoot()); if (!sources.isEmpty()) { - try (CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager::getSplits)) { + try (CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager)) { SplitSourceFactory splitSourceFactory = new SplitSourceFactory(splitSourceProvider, WarningCollector.NOOP); Map splitSources = splitSourceFactory.createSplitSources(fragment, session, tableWriteInfo); taskSourceRdd = Optional.of(createTaskSourcesRdd( diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java b/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java index 55bee3238e34d..5dca9ab76d47c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/ConnectorHandleResolver.java @@ -15,6 +15,7 @@ import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; public interface ConnectorHandleResolver { @@ -65,4 +66,9 @@ default Class getTransactionHandleClass() { throw new UnsupportedOperationException(); } + + default Class getTableFunctionHandleClass() + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/CoordinatorPlugin.java b/presto-spi/src/main/java/com/facebook/presto/spi/CoordinatorPlugin.java index 6b82a6ae2b64e..f263fd83f63a6 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/CoordinatorPlugin.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/CoordinatorPlugin.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.plan.PlanCheckerProviderFactory; import com.facebook.presto.spi.session.WorkerSessionPropertyProviderFactory; import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; +import com.facebook.presto.spi.tvf.TVFProviderFactory; import com.facebook.presto.spi.type.TypeManagerFactory; import static java.util.Collections.emptyList; @@ -52,4 +53,9 @@ default Iterable getTypeManagerFactories() { return emptyList(); } + + default Iterable getTVFProviderFactories() + { + return emptyList(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java b/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java index 819837b524233..21366d8c99968 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java @@ -80,6 +80,7 @@ public enum StandardErrorCode DATATYPE_MISMATCH(0x0000_0036, USER_ERROR), SESSION_CATALOG_NOT_SET(0x0000_0037, USER_ERROR), MV_MISSING_TOO_MUCH_DATA(0x0000_0038, USER_ERROR), + TABLE_FUNCTION_ANALYSIS_FAILED(0x0000_0038, USER_ERROR), GENERIC_INTERNAL_ERROR(0x0001_0000, INTERNAL_ERROR), TOO_MANY_REQUESTS_FAILED(0x0001_0001, INTERNAL_ERROR, true), @@ -152,6 +153,7 @@ public enum StandardErrorCode DUPLICATE_FUNCTION_ERROR(0x0002_0016, INTERNAL_ERROR), MEMORY_ARBITRATION_FAILURE(0x0002_0017, INSUFFICIENT_RESOURCES), AUTHENTICATOR_NOT_APPLICABLE(0x0002_0018, INTERNAL_ERROR), + TABLE_FUNCTION_NOT_FOUND(0x0002_0019, USER_ERROR), /**/; // Error code range 0x0003 is reserved for Presto-on-Spark diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/Connector.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/Connector.java index d2d0ce779a77e..5fb86e8926ff0 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/Connector.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/Connector.java @@ -15,6 +15,8 @@ import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.procedure.DistributedProcedure; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.session.PropertyMetadata; @@ -22,6 +24,7 @@ import java.util.List; import java.util.Set; +import java.util.function.Function; import static com.facebook.presto.spi.connector.EmptyConnectorCommitHandle.INSTANCE; import static java.util.Collections.emptyList; @@ -127,6 +130,16 @@ default Set getTableFunctions() return emptySet(); } + /** + * @return the table function processor provider for the connector + */ + default Function getTableFunctionProcessorProvider() + { + return handle -> { + throw new UnsupportedOperationException(); + }; + } + /** * @return the set of functions provided by this connector */ diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorCodecProvider.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorCodecProvider.java index 98e71be2e266d..60172b3ce9cec 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorCodecProvider.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorCodecProvider.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import java.util.Optional; @@ -84,6 +85,11 @@ default Optional> getConnectorIndexHandleCo return Optional.empty(); } + default Optional> getConnectorTableFunctionHandleCodec() + { + return Optional.empty(); + } + default Optional> getConnectorDistributedProcedureHandleCodec() { return Optional.empty(); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorFactory.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorFactory.java index 07b36b4dca528..22e658686adee 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorFactory.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorFactory.java @@ -14,8 +14,14 @@ package com.facebook.presto.spi.connector; import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import java.util.Map; +import java.util.Optional; +import java.util.function.Function; public interface ConnectorFactory { @@ -24,4 +30,19 @@ public interface ConnectorFactory ConnectorHandleResolver getHandleResolver(); Connector create(String catalogName, Map config, ConnectorContext context); + + default Function getTableFunctionProcessorProvider() + { + return null; + } + + default Optional getTableFunctionHandleResolver() + { + return Optional.empty(); + } + + default Optional getTableFunctionSplitResolver() + { + return Optional.empty(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorSplitManager.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorSplitManager.java index 69ac79c9f7522..8736640c5f41d 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorSplitManager.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorSplitManager.java @@ -17,6 +17,7 @@ import com.facebook.presto.spi.ConnectorSplitSource; import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import static java.util.Objects.requireNonNull; @@ -71,4 +72,12 @@ public WarningCollector getWarningCollector() return warningCollector; } } + + default ConnectorSplitSource getSplits( + ConnectorTransactionHandle transaction, + ConnectorSession session, + ConnectorTableFunctionHandle function) + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorSplitManager.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorSplitManager.java index 4efb85e07c088..815ba36ff4974 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorSplitManager.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorSplitManager.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import static java.util.Objects.requireNonNull; @@ -41,4 +42,12 @@ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHand return delegate.getSplits(transactionHandle, session, layout, splitSchedulingContext); } } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorTableFunctionHandle function) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getSplits(transaction, session, function); + } + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionHandleResolver.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionHandleResolver.java new file mode 100644 index 0000000000000..fd24b9c694c50 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionHandleResolver.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; + +import java.util.Set; + +public interface TableFunctionHandleResolver +{ + Set> getTableFunctionHandleClasses(); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionSplitResolver.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionSplitResolver.java new file mode 100644 index 0000000000000..2a31b1a9aa113 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/TableFunctionSplitResolver.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import com.facebook.presto.spi.ConnectorSplit; + +import java.util.Set; + +public interface TableFunctionSplitResolver +{ + Set> getTableFunctionSplitClasses(); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/AbstractConnectorTableFunction.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/AbstractConnectorTableFunction.java index 1be3f816190f4..30655cb25450b 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/AbstractConnectorTableFunction.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/AbstractConnectorTableFunction.java @@ -15,6 +15,8 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import java.util.ArrayList; import java.util.Collections; @@ -44,7 +46,12 @@ public abstract class AbstractConnectorTableFunction private final List arguments; private final ReturnTypeSpecification returnTypeSpecification; - public AbstractConnectorTableFunction(String schema, String name, List arguments, ReturnTypeSpecification returnTypeSpecification) + @JsonCreator + public AbstractConnectorTableFunction( + @JsonProperty("schema") String schema, + @JsonProperty("name") String name, + @JsonProperty("arguments") List arguments, + @JsonProperty("returnTypeSpecification") ReturnTypeSpecification returnTypeSpecification) { this.schema = requireNonNull(schema, "schema is null"); this.name = requireNonNull(name, "name is null"); @@ -52,24 +59,28 @@ public AbstractConnectorTableFunction(String schema, String name, List getArguments() { return arguments; } + @JsonProperty @Override public ReturnTypeSpecification getReturnTypeSpecification() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ArgumentSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ArgumentSpecification.java index 73c822095863f..404916b239f7d 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ArgumentSpecification.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ArgumentSpecification.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.spi.function.table; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; import jakarta.annotation.Nullable; import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; @@ -28,6 +31,14 @@ *

* Default values are allowed for all arguments except Table arguments. */ +@JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "@type") +@JsonSubTypes({ + @JsonSubTypes.Type(value = DescriptorArgumentSpecification.class, name = "descriptor"), + @JsonSubTypes.Type(value = TableArgumentSpecification.class, name = "table"), + @JsonSubTypes.Type(value = ScalarArgumentSpecification.class, name = "scalar")}) public abstract class ArgumentSpecification { public static final String argumentType = "Abstract"; @@ -45,16 +56,19 @@ public abstract class ArgumentSpecification this.defaultValue = defaultValue; } + @JsonProperty public String getName() { return name; } + @JsonProperty public boolean isRequired() { return required; } + @JsonProperty public Object getDefaultValue() { return defaultValue; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ConnectorTableFunctionHandle.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ConnectorTableFunctionHandle.java index 6b535309b6870..58933b68dd5f8 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ConnectorTableFunctionHandle.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ConnectorTableFunctionHandle.java @@ -13,6 +13,10 @@ */ package com.facebook.presto.spi.function.table; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.fasterxml.jackson.annotation.JsonInclude; /** @@ -21,4 +25,11 @@ @JsonInclude(JsonInclude.Include.ALWAYS) public interface ConnectorTableFunctionHandle { + default ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, + ConnectorSession session, + NodeManager nodeManager, + Object functionAndTypeManager) + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescribedTableReturnTypeSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescribedTableReturnTypeSpecification.java new file mode 100644 index 0000000000000..0798ba2cde237 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescribedTableReturnTypeSpecification.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +/** + * The proper columns of the table function are known at function declaration time. + * They do not depend on the actual call arguments. + */ +public class DescribedTableReturnTypeSpecification + extends ReturnTypeSpecification +{ + private final Descriptor descriptor; + private static final String returnType = "DESCRIBED"; + + @JsonCreator + public DescribedTableReturnTypeSpecification(@JsonProperty("descriptor") Descriptor descriptor) + { + requireNonNull(descriptor, "descriptor is null"); + checkArgument(descriptor.isTyped(), "field types not specified"); + this.descriptor = descriptor; + } + + public Descriptor getDescriptor() + { + return descriptor; + } + + @Override + public String getReturnType() + { + return returnType; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgument.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgument.java index f8aa9bab06408..ce7474c87e908 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgument.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgument.java @@ -32,7 +32,7 @@ public class DescriptorArgument private final Optional descriptor; @JsonCreator - private DescriptorArgument(@JsonProperty("descriptor") Optional descriptor) + public DescriptorArgument(@JsonProperty("descriptor") Optional descriptor) { this.descriptor = requireNonNull(descriptor, "descriptor is null"); descriptor.ifPresent(descriptorValue -> checkArgument( diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgumentSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgumentSpecification.java index 11ed93f02cd00..221e0c36b7665 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgumentSpecification.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/DescriptorArgumentSpecification.java @@ -13,11 +13,18 @@ */ package com.facebook.presto.spi.function.table; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + public class DescriptorArgumentSpecification extends ArgumentSpecification { public static final String argumentType = "DescriptorArgumentSpecification"; - private DescriptorArgumentSpecification(String name, boolean required, Descriptor defaultValue) + @JsonCreator + public DescriptorArgumentSpecification( + @JsonProperty("name") String name, + @JsonProperty("required") boolean required, + @JsonProperty("defaultValue") Descriptor defaultValue) { super(name, required, defaultValue); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/GenericTableReturnTypeSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/GenericTableReturnTypeSpecification.java new file mode 100644 index 0000000000000..4cd17b932026b --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/GenericTableReturnTypeSpecification.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +/** + * The proper columns of the table function are not known at function declaration time. + * They must be determined at query analysis time based on the actual call arguments. + */ +public class GenericTableReturnTypeSpecification + extends ReturnTypeSpecification +{ + public static final GenericTableReturnTypeSpecification GENERIC_TABLE = new GenericTableReturnTypeSpecification(""); + private static final String returnType = "GENERIC"; + + @JsonCreator + public GenericTableReturnTypeSpecification(@JsonProperty("returnType") String returnType) + {} + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GenericTableReturnTypeSpecification that = (GenericTableReturnTypeSpecification) o; + return Objects.equals(returnType, that.returnType); + } + + @Override + public int hashCode() + { + return Objects.hash(returnType); + } + + @Override + public String getReturnType() + { + return returnType; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/OnlyPassThroughReturnTypeSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/OnlyPassThroughReturnTypeSpecification.java new file mode 100644 index 0000000000000..f810628a95248 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/OnlyPassThroughReturnTypeSpecification.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * The table function has no proper columns. + */ +public class OnlyPassThroughReturnTypeSpecification + extends ReturnTypeSpecification +{ + public static final OnlyPassThroughReturnTypeSpecification ONLY_PASS_THROUGH = new OnlyPassThroughReturnTypeSpecification(""); + private static final String returnType = "PASSTRHOUGH"; + + @JsonCreator + public OnlyPassThroughReturnTypeSpecification(@JsonProperty("returnType") String returnType) + { + } + + @Override + public String getReturnType() + { + return returnType; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ReturnTypeSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ReturnTypeSpecification.java index 2c77503cadada..db45496f4e7e4 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ReturnTypeSpecification.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ReturnTypeSpecification.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.spi.function.table; -import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; -import static java.util.Objects.requireNonNull; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; /** * The return type declaration refers to the proper columns of the table function. @@ -25,77 +25,15 @@ * dynamically determined at analysis time (GenericTable), or simply passed through * from input tables without adding new columns (OnlyPassThrough). */ +@JsonTypeInfo( + use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "@type") +@JsonSubTypes({ + @JsonSubTypes.Type(value = GenericTableReturnTypeSpecification.class, name = "generic_table"), + @JsonSubTypes.Type(value = OnlyPassThroughReturnTypeSpecification.class, name = "only_pass_through_table"), + @JsonSubTypes.Type(value = DescribedTableReturnTypeSpecification.class, name = "described_table")}) public abstract class ReturnTypeSpecification { - public static final String returnType = "Abstract"; - /** - * The proper columns of the table function are not known at function declaration time. - * They must be determined at query analysis time based on the actual call arguments. - */ - public static class GenericTable - extends ReturnTypeSpecification - { - public static final String returnType = "GenericTable"; - public static final GenericTable GENERIC_TABLE = new GenericTable(); - - private GenericTable() {} - - @Override - public String getReturnType() - { - return returnType; - } - } - - /** - * The table function has no proper columns. - */ - public static class OnlyPassThrough - extends ReturnTypeSpecification - { - public static final String returnType = "OnlyPassThrough"; - public static final OnlyPassThrough ONLY_PASS_THROUGH = new OnlyPassThrough(); - - private OnlyPassThrough() {} - - @Override - public String getReturnType() - { - return returnType; - } - } - - /** - * The proper columns of the table function are known at function declaration time. - * They do not depend on the actual call arguments. - */ - public static class DescribedTable - extends ReturnTypeSpecification - { - public static final String returnType = "DescribedTable"; - private final Descriptor descriptor; - - public DescribedTable(Descriptor descriptor) - { - requireNonNull(descriptor, "descriptor is null"); - checkArgument(descriptor.isTyped(), "field types not specified"); - this.descriptor = descriptor; - } - - public Descriptor getDescriptor() - { - return descriptor; - } - - @Override - public String getReturnType() - { - return returnType; - } - } - - public String getReturnType() - { - return returnType; - } + public abstract String getReturnType(); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ScalarArgumentSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ScalarArgumentSpecification.java index 94f98bafe949a..fc9580c829413 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ScalarArgumentSpecification.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/ScalarArgumentSpecification.java @@ -15,6 +15,9 @@ import com.facebook.presto.common.predicate.Primitives; import com.facebook.presto.common.type.Type; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.annotation.Nullable; import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; import static java.lang.String.format; @@ -26,7 +29,12 @@ public class ScalarArgumentSpecification public static final String argumentType = "ScalarArgumentSpecification"; private final Type type; - private ScalarArgumentSpecification(String name, Type type, boolean required, Object defaultValue) + @JsonCreator + public ScalarArgumentSpecification( + @JsonProperty("name") String name, + @JsonProperty("type") Type type, + @JsonProperty("required") boolean required, + @Nullable Object defaultValue) { super(name, required, defaultValue); this.type = requireNonNull(type, "type is null"); @@ -35,6 +43,7 @@ private ScalarArgumentSpecification(String name, Type type, boolean required, Ob } } + @JsonProperty public Type getType() { return type; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgument.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgument.java index 4526f80fd5b76..fbfc9d1fab98e 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgument.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgument.java @@ -36,7 +36,8 @@ public class TableArgument public TableArgument( @JsonProperty("rowType") RowType rowType, @JsonProperty("partitionBy") List partitionBy, - @JsonProperty("orderBy") List orderBy) + @JsonProperty("orderBy") List orderBy, + @JsonProperty("fields") List fields) { this.rowType = requireNonNull(rowType, "rowType is null"); this.partitionBy = requireNonNull(partitionBy, "partitionBy is null"); @@ -49,6 +50,12 @@ public RowType getRowType() return rowType; } + @JsonProperty + public List getFields() + { + return rowType.getFields(); + } + @JsonProperty public List getPartitionBy() { @@ -106,7 +113,7 @@ public Builder orderBy(List orderBy) public TableArgument build() { - return new TableArgument(rowType, partitionBy, orderBy); + return new TableArgument(rowType, partitionBy, orderBy, rowType.getFields()); } } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgumentSpecification.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgumentSpecification.java index 44c74258152dd..3db545d53e1b2 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgumentSpecification.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableArgumentSpecification.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.spi.function.table; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @@ -24,7 +27,12 @@ public class TableArgumentSpecification private final boolean pruneWhenEmpty; private final boolean passThroughColumns; - private TableArgumentSpecification(String name, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns) + @JsonCreator + public TableArgumentSpecification( + @JsonProperty("name") String name, + @JsonProperty("rowSemantics") boolean rowSemantics, + @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, + @JsonProperty("passThroughColumns") boolean passThroughColumns) { super(name, true, null); @@ -36,16 +44,19 @@ private TableArgumentSpecification(String name, boolean rowSemantics, boolean pr this.passThroughColumns = passThroughColumns; } + @JsonProperty public boolean isRowSemantics() { return rowSemantics; } + @JsonProperty public boolean isPruneWhenEmpty() { return pruneWhenEmpty; } + @JsonProperty public boolean isPassThroughColumns() { return passThroughColumns; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionDataProcessor.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionDataProcessor.java new file mode 100644 index 0000000000000..8a5b176e60ae7 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionDataProcessor.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.common.Page; + +import java.util.List; +import java.util.Optional; + +public interface TableFunctionDataProcessor +{ + /** + * This method processes a portion of data. It is called multiple times until the partition is fully processed. + * + * @param input a tuple of {@link Page} including one page for each table function's input table. + * Pages list is ordered according to the corresponding argument specifications in {@link ConnectorTableFunction}. + * A page for an argument consists of columns requested during analysis (see {@link TableFunctionAnalysis#getRequiredColumns()}}. + * If any of the sources is fully processed, {@code Optional.empty)()} is returned for that source. + * If all sources are fully processed, the argument is {@code null}. + * @return {@link TableFunctionProcessorState} including the processor's state and optionally a portion of result. + * After the returned state is {@code FINISHED}, the method will not be called again. + */ + TableFunctionProcessorState process(List> input); +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionMetadata.java similarity index 91% rename from presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionMetadata.java rename to presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionMetadata.java index 806215927b736..4a8007bdd467b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionMetadata.java @@ -11,10 +11,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.metadata; +package com.facebook.presto.spi.function.table; import com.facebook.presto.spi.ConnectorId; -import com.facebook.presto.spi.function.table.ConnectorTableFunction; import static java.util.Objects.requireNonNull; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorProvider.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorProvider.java new file mode 100644 index 0000000000000..556e3828eb79c --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorProvider.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +public interface TableFunctionProcessorProvider +{ + /** + * This method returns a {@code TableFunctionDataProcessor}. All the necessary information collected during analysis is available + * in the form of {@link ConnectorTableFunctionHandle}. It is called once per each partition processed by the table function. + */ + default TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + throw new UnsupportedOperationException("this table function does not process input data"); + } + + /** + * This method returns a {@code TableFunctionSplitProcessor}. All the necessary information collected during analysis is available + * in the form of {@link ConnectorTableFunctionHandle}. It is called once per each split processed by the table function. + */ + default TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + throw new UnsupportedOperationException("this table function does not process splits"); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorState.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorState.java new file mode 100644 index 0000000000000..581407fb46633 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionProcessorState.java @@ -0,0 +1,105 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.common.Page; +import jakarta.annotation.Nullable; + +import java.util.concurrent.CompletableFuture; + +import static java.util.Objects.requireNonNull; + +/** + * The result of processing input by {@link TableFunctionDataProcessor} or {@link TableFunctionSplitProcessor}. + * It can optionally include a portion of output data in the form of {@link Page} + * The returned {@link Page} should consist of: + * - proper columns produced by the table function + * - one column of type {@code BIGINT} for each table function's input table having the pass-through property (see {@link TableArgumentSpecification#isPassThroughColumns}), + * in order of the corresponding argument specifications. Entries in these columns are the indexes of input rows (from partition start) to be attached to output, + * or null to indicate that a row of nulls should be attached instead of an input row. The indexes are validated to be within the portion of the partition + * provided to the function so far. + * Note: when the input is empty, the only valid index value is null, because there are no input rows that could be attached to output. In such case, for performance + * reasons, the validation of indexes is skipped, and all pass-through columns are filled with nulls. + */ +public interface TableFunctionProcessorState +{ + final class Blocked + implements TableFunctionProcessorState + { + private final CompletableFuture future; + + private Blocked(CompletableFuture future) + { + this.future = requireNonNull(future, "future is null"); + } + + public static Blocked blocked(CompletableFuture future) + { + return new Blocked(future); + } + + public CompletableFuture getFuture() + { + return future; + } + } + + final class Finished + implements TableFunctionProcessorState + { + public static final Finished FINISHED = new Finished(); + + private Finished() {} + } + + final class Processed + implements TableFunctionProcessorState + { + private final boolean usedInput; + private final Page result; + + private Processed(boolean usedInput, @Nullable Page result) + { + this.usedInput = usedInput; + this.result = result; + } + + public static Processed usedInput() + { + return new Processed(true, null); + } + + public static Processed produced(Page result) + { + requireNonNull(result, "result is null"); + return new Processed(false, result); + } + + public static Processed usedInputAndProduced(Page result) + { + requireNonNull(result, "result is null"); + return new Processed(true, result); + } + + public boolean isUsedInput() + { + return usedInput; + } + + public Page getResult() + { + return result; + } + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionSplitProcessor.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionSplitProcessor.java new file mode 100644 index 0000000000000..504ea54fcb61f --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/table/TableFunctionSplitProcessor.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function.table; + +import com.facebook.presto.spi.ConnectorSplit; + +public interface TableFunctionSplitProcessor +{ + /** + * This method processes a split. It is called multiple times until the whole output for the split is produced. + * + * @param split a {@link ConnectorSplit} representing a subtask. + * @return {@link TableFunctionProcessorState} including the processor's state and optionally a portion of result. + * After the returned state is {@code FINISHED}, the method will not be called again. + */ + TableFunctionProcessorState process(ConnectorSplit split); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/tvf/TVFProvider.java b/presto-spi/src/main/java/com/facebook/presto/spi/tvf/TVFProvider.java new file mode 100644 index 0000000000000..e4c472622c46b --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/tvf/TVFProvider.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.tvf; + +import com.facebook.presto.spi.function.table.ConnectorTableFunction; + +import java.util.List; + +public interface TVFProvider +{ + List getTableFunctions(); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/tvf/TVFProviderContext.java b/presto-spi/src/main/java/com/facebook/presto/spi/tvf/TVFProviderContext.java new file mode 100644 index 0000000000000..ea55aa464d111 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/tvf/TVFProviderContext.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.tvf; + +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.spi.NodeManager; + +import static java.util.Objects.requireNonNull; + +public class TVFProviderContext +{ + private final NodeManager nodeManager; + private final TypeManager typeManager; + + public TVFProviderContext(NodeManager nodeManager, TypeManager typeManager) + { + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + } + + public NodeManager getNodeManager() + { + return nodeManager; + } + + public TypeManager getTypeManager() + { + return typeManager; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/tvf/TVFProviderFactory.java b/presto-spi/src/main/java/com/facebook/presto/spi/tvf/TVFProviderFactory.java new file mode 100644 index 0000000000000..3f7e60d033d20 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/tvf/TVFProviderFactory.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.tvf; + +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; + +import java.util.Map; + +public interface TVFProviderFactory +{ + TVFProvider createTVFProvider(Map config, TVFProviderContext context); + + TableFunctionHandleResolver getTableFunctionHandleResolver(); + + TableFunctionSplitResolver getTableFunctionSplitResolver(); + + String getName(); +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index 3f95501b03cd9..dac8717f95c5e 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -323,6 +323,11 @@ protected void assertQueryFails(@Language("SQL") String sql, @Language("RegExp") QueryAssertions.assertQueryFails(queryRunner, getSession(), sql, expectedMessageRegExp); } + protected void assertQueryFailsExact(@Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) + { + QueryAssertions.assertQueryFails(queryRunner, getSession(), sql, expectedMessageRegExp, false, true); + } + protected void assertQueryFails(QueryRunner queryRunner, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) { QueryAssertions.assertQueryFails(queryRunner, getSession(), sql, expectedMessageRegExp); @@ -330,7 +335,7 @@ protected void assertQueryFails(QueryRunner queryRunner, @Language("SQL") String protected void assertQueryFails(@Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp, boolean usePatternMatcher) { - QueryAssertions.assertQueryFails(queryRunner, getSession(), sql, expectedMessageRegExp, usePatternMatcher); + QueryAssertions.assertQueryFails(queryRunner, getSession(), sql, expectedMessageRegExp, usePatternMatcher, false); } protected void assertQueryFails(Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) @@ -355,7 +360,7 @@ protected void assertQueryError(@Language("SQL") String sql, @Language("RegExp") protected void assertQueryFails(Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp, boolean usePatternMatcher) { - QueryAssertions.assertQueryFails(queryRunner, session, sql, expectedMessageRegExp, usePatternMatcher); + QueryAssertions.assertQueryFails(queryRunner, session, sql, expectedMessageRegExp, usePatternMatcher, false); } protected void assertQueryReturnsEmptyResult(@Language("SQL") String sql) diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java index b9f88018332f8..62d2a477f897f 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java @@ -1090,6 +1090,16 @@ public void registerNativeFunctions() } } + @Override + public void loadTVFProvider(String tvfProviderName) + { + for (TestingPrestoServer server : servers) { + server.getMetadata().getFunctionAndTypeManager().loadTVFProvider( + tvfProviderName, + server.getPluginNodeManager()); + } + } + private static void closeUnchecked(AutoCloseable closeable) { try { diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java b/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java index 207a52ed95bce..ed78d28aaf6ed 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/QueryAssertions.java @@ -381,18 +381,18 @@ protected static void assertQueryFails(QueryRunner queryRunner, Session session, fail(format("Expected query to fail: %s", sql)); } catch (RuntimeException ex) { - assertExceptionMessage(sql, ex, expectedMessageRegExp, false); + assertExceptionMessage(sql, ex, expectedMessageRegExp, false, false); } } - protected static void assertQueryFails(QueryRunner queryRunner, Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp, boolean usePatternMatcher) + protected static void assertQueryFails(QueryRunner queryRunner, Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp, boolean usePatternMatcher, boolean exact) { try { queryRunner.execute(session, sql); fail(format("Expected query to fail: %s", sql)); } catch (RuntimeException ex) { - assertExceptionMessage(sql, ex, expectedMessageRegExp, usePatternMatcher); + assertExceptionMessage(sql, ex, expectedMessageRegExp, usePatternMatcher, exact); } } @@ -408,7 +408,7 @@ protected static void assertQueryReturnsEmptyResult(QueryRunner queryRunner, Ses } } - public static void assertExceptionMessage(String sql, Exception exception, @Language("RegExp") String regex, boolean usePatternMatcher) + public static void assertExceptionMessage(String sql, Exception exception, @Language("RegExp") String regex, boolean usePatternMatcher, boolean exact) { if (usePatternMatcher) { Pattern p = Pattern.compile(regex, Pattern.MULTILINE); @@ -417,7 +417,7 @@ public static void assertExceptionMessage(String sql, Exception exception, @Lang } } else { - if (!nullToEmpty(exception.getMessage()).matches(regex)) { + if (!(exact ? nullToEmpty(exception.getMessage()).equals(regex) : nullToEmpty(exception.getMessage()).matches(regex))) { fail(format("Expected exception message '%s' to match '%s' for query: %s", exception.getMessage(), regex, sql), exception); } } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/TestExcludeColumnsFunction.java b/presto-tests/src/main/java/com/facebook/presto/tests/TestExcludeColumnsFunction.java new file mode 100644 index 0000000000000..f55c66626cefe --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/TestExcludeColumnsFunction.java @@ -0,0 +1,200 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests; + +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tpch.TpchPlugin; +import org.testng.annotations.Test; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; + +public class TestExcludeColumnsFunction + extends AbstractTestQueryFramework +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(testSessionBuilder().build()).build(); + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpch", "tpch"); + return queryRunner; + } + + @Override + protected QueryRunner createExpectedQueryRunner() + throws Exception + { + DistributedQueryRunner result = DistributedQueryRunner.builder(testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .build()) + .build(); + result.installPlugin(new TpchPlugin()); + result.createCatalog("tpch", "tpch"); + return result; + } + + @Test + public void testExcludeColumnsFunction() + { + assertQuery("SELECT * FROM tpch.tiny.nation", + "SELECT nationkey, name, regionkey, comment FROM tpch.tiny.nation"); + + assertQuery("SELECT * " + + "FROM TABLE(exclude_columns( " + + " input => TABLE(tpch.tiny.nation)," + + " columns => DESCRIPTOR(comment)))", + "SELECT nationkey, name, regionkey FROM tpch.tiny.nation"); + + assertQuery("SELECT * " + + "FROM TABLE(exclude_columns( " + + " input => TABLE(tpch.tiny.nation), " + + " columns => DESCRIPTOR(regionkey, nationkey)))", + "SELECT name, comment FROM tpch.tiny.nation"); + } + + @Test + public void testInvalidArgument() + { + assertQueryFails("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => CAST(null AS DESCRIPTOR)))\n", + "COLUMNS descriptor is null"); + + assertQueryFailsExact("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR()))\n", + "line 4:21: Invalid descriptor argument COLUMNS. Descriptors should be formatted as 'DESCRIPTOR(name [type], ...)'"); + + assertQueryFailsExact("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(foo, comment, bar)))\n", + "Excluded columns: [foo, bar] not present in the table"); + + assertQueryFails("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(nationkey bigint, comment)))\n", + "COLUMNS descriptor contains types"); + + assertQueryFails("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(nationkey, name, regionkey, comment)))\n", + "All columns are excluded"); + } + + @Test + public void testColumnResolution() + { + // excluded column names are matched case-insensitive + assertQuery("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(SELECT 1, 2, 3, 4, 5) t(a, B, \"c\", \"D\", e),\n" + + " columns => DESCRIPTOR(\"A\", \"b\", C, d)))\n", + "SELECT 5"); + } + + @Test + public void testReturnedColumnNames() + { + // the function preserves the incoming column names. (However, due to how the analyzer handles identifiers, these are not the canonical names according to the SQL identifier semantics.) + assertQuery("SELECT a, b, c, d\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(SELECT 1, 2, 3, 4, 5) t(a, B, \"c\", \"D\", e),\n" + + " columns => DESCRIPTOR(e)))\n", + "SELECT 1, 2, 3, 4"); + } + + @Test + public void testHiddenColumn() + { + assertQuery("SELECT row_number FROM tpch.tiny.region", + "SELECT * FROM UNNEST(sequence(0, 4))"); + + // the hidden column is not provided to the function + assertQueryFails("SELECT row_number\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(comment)))\n", + "line 1:8: Column 'row_number' cannot be resolved"); + + assertQueryFailsExact("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(row_number)))\n", + "Excluded columns: [row_number] not present in the table"); + } + + @Test + public void testAnonymousColumn() + { + // cannot exclude an unnamed columns. the unnamed columns are passed on unnamed. + assertQuery("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(SELECT 1 a, 2, 3 c, 4),\n" + + " columns => DESCRIPTOR(a, c)))\n", + "SELECT 2, 4"); + } + + @Test + public void testDuplicateExcludedColumn() + { + // duplicates in excluded column names are allowed + assertQuery("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(comment, name, comment)))\n", + "SELECT nationkey, regionkey FROM tpch.tiny.nation"); + } + + @Test + public void testDuplicateInputColumn() + { + // all input columns with given name are excluded + assertQuery("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(SELECT 1, 2, 3, 4, 5) t(a, b, c, a, b),\n" + + " columns => DESCRIPTOR(a, b)))\n", + "SELECT 3"); + } + + @Test + public void testFunctionResolution() + { + assertQuery("SELECT *\n" + + "FROM TABLE(system.builtin.exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(comment)))\n", + "SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.nation),\n" + + " columns => DESCRIPTOR(comment)))\n"); + } + + @Test + public void testBigInput() + { + assertQuery("SELECT *\n" + + "FROM TABLE(exclude_columns(\n" + + " input => TABLE(tpch.tiny.orders),\n" + + " columns => DESCRIPTOR(orderstatus, orderdate, orderpriority, clerk, shippriority, comment)))\n", + "SELECT orderkey, custkey, totalprice FROM tpch.tiny.orders"); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/TestSequenceFunction.java b/presto-tests/src/main/java/com/facebook/presto/tests/TestSequenceFunction.java new file mode 100644 index 0000000000000..dd6bc1b8946aa --- /dev/null +++ b/presto-tests/src/main/java/com/facebook/presto/tests/TestSequenceFunction.java @@ -0,0 +1,291 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests; + +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tpch.TpchPlugin; +import org.testng.annotations.Test; + +import static com.facebook.presto.operator.table.Sequence.SequenceFunctionSplit.DEFAULT_SPLIT_SIZE; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static java.lang.String.format; + +public class TestSequenceFunction + extends AbstractTestQueryFramework +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(testSessionBuilder().build()).build(); + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpch", "tpch"); + return queryRunner; + } + + @Override + protected QueryRunner createExpectedQueryRunner() + throws Exception + { + DistributedQueryRunner result = DistributedQueryRunner.builder(testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .build()) + .build(); + result.installPlugin(new TpchPlugin()); + result.createCatalog("tpch", "tpch"); + return result; + } + + @Test + public void testSequence() + { + assertQuery("SELECT * FROM TABLE(sequence(0, 8000, 3))", + "SELECT * FROM UNNEST(sequence(0, 8000, 3))"); + + assertQuery("SELECT * FROM TABLE(sequence(1, 10, 3))", + "VALUES BIGINT '1', 4, 7, 10"); + + assertQuery("SELECT * FROM TABLE(sequence(1, 10, 6))", + "VALUES BIGINT '1', 7"); + + assertQuery("SELECT * FROM TABLE(sequence(-1, -10, -3))", + "VALUES BIGINT '-1', -4, -7, -10"); + + assertQuery("SELECT * FROM TABLE(sequence(-1, -10, -6))", + "VALUES BIGINT '-1', -7"); + + assertQuery("SELECT * FROM TABLE(sequence(-5, 5, 3))", + "VALUES BIGINT '-5', -2, 1, 4"); + + assertQuery("SELECT * FROM TABLE(sequence(5, -5, -3))", + "VALUES BIGINT '5', 2, -1, -4"); + + assertQuery("SELECT * FROM TABLE(sequence(0, 10, 3))", + "VALUES BIGINT '0', 3, 6, 9"); + + assertQuery("SELECT * FROM TABLE(sequence(0, -10, -3))", + "VALUES BIGINT '0', -3, -6, -9"); + } + + @Test + public void testDefaultArguments() + { + assertQuery("SELECT * FROM TABLE(sequence(stop => 10))", + "SELECT * FROM UNNEST(sequence(0, 10, 1))"); + } + + @Test + public void testInvalidArgument() + { + assertQueryFailsExact("SELECT * " + + "FROM TABLE(sequence( " + + " start => -5," + + " stop => 10," + + " step => -2))", + "Step must be positive for sequence [-5, 10]"); + + assertQueryFailsExact("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => -5," + + " step => 2))", + "Step must be negative for sequence [10, -5]"); + + assertQueryFailsExact("SELECT * " + + "FROM TABLE(sequence(" + + " start => null," + + " stop => -5," + + " step => 2))", + "Start is null"); + + assertQueryFailsExact("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => null," + + " step => 2))", + "Stop is null"); + + assertQueryFailsExact("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => -5," + + " step => null))", + "Step is null"); + } + + @Test + public void testSingletonSequence() + { + assertQuery("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => 10," + + " step => 2))", + "VALUES BIGINT '10'"); + + assertQuery("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => 10," + + " step => -2))", + "VALUES BIGINT '10'"); + + assertQuery("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => 10," + + " step => 0))", + "VALUES BIGINT '10'"); + } + + @Test + public void testBigStep() + { + assertQuery(format("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => -5," + + " step => %s))", + Long.MIN_VALUE / (DEFAULT_SPLIT_SIZE - 1)), "VALUES BIGINT '10'"); + + assertQuery(format("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => -5," + + " step => %s))", + Long.MIN_VALUE / (DEFAULT_SPLIT_SIZE - 1) - 1), + "VALUES BIGINT '10'"); + + assertQuery(format("SELECT DISTINCT x - lag(x, 1) OVER(ORDER BY x DESC) \n" + + "FROM TABLE(sequence(\n" + + " start => %s,\n" + + " stop => BIGINT '%s',\n" + + " step => %s)) t(x)", + Long.MAX_VALUE, Long.MIN_VALUE, Long.MIN_VALUE / (DEFAULT_SPLIT_SIZE - 1) - 1), + format("VALUES (null), (%s)", Long.MIN_VALUE / (DEFAULT_SPLIT_SIZE - 1) - 1)); + + assertQuery(format("SELECT * " + + "FROM TABLE(sequence(" + + " start => 10," + + " stop => -5," + + " step => BIGINT '%s'))", Long.MIN_VALUE), + "VALUES BIGINT '10'"); + + assertQuery(format("SELECT * " + + "FROM TABLE(sequence(" + + " start => -5," + + " stop => 10," + + " step => %s))", Long.MAX_VALUE / (DEFAULT_SPLIT_SIZE - 1)), + "VALUES BIGINT '-5'"); + + assertQuery(format("SELECT * " + + "FROM TABLE(sequence(" + + " start => -5," + + " stop => 10," + + " step => %s))", Long.MAX_VALUE / (DEFAULT_SPLIT_SIZE - 1) + 1), + "VALUES BIGINT '-5'"); + + assertQuery(format("SELECT DISTINCT x - lag(x, 1) OVER(ORDER BY x) " + + "FROM TABLE(sequence(" + + " start => BIGINT '%s'," + + " stop => %s," + + " step => %s)) t(x)", + Long.MIN_VALUE, Long.MAX_VALUE, Long.MAX_VALUE / (DEFAULT_SPLIT_SIZE - 1) + 1), + format("VALUES (null), (%s)", Long.MAX_VALUE / (DEFAULT_SPLIT_SIZE - 1) + 1)); + + assertQuery(format("SELECT * " + + "FROM TABLE(sequence(" + + " start => -5," + + " stop => 10," + + " step => %s))", Long.MAX_VALUE), + "VALUES BIGINT '-5'"); + } + + @Test + public void testMultipleSplits() + { + long sequenceLength = DEFAULT_SPLIT_SIZE * 10 + DEFAULT_SPLIT_SIZE / 2; + long start = 10; + long step = 5; + long stop = start + (sequenceLength - 1) * step; + assertQuery(format("SELECT count(x), count(DISTINCT x), min(x), max(x) " + + "FROM TABLE(sequence( " + + " start => %s," + + " stop => %s," + + " step => %s)) t(x)", start, stop, step), + format("SELECT BIGINT '%s', BIGINT '%s', BIGINT '%s', BIGINT '%s'", sequenceLength, sequenceLength, start, stop)); + + sequenceLength = DEFAULT_SPLIT_SIZE * 4 + DEFAULT_SPLIT_SIZE / 2; + stop = start + (sequenceLength - 1) * step; + assertQuery(format("SELECT min(x), max(x) " + + "FROM TABLE(sequence(" + + " start => %s," + + " stop => %s," + + " step => %s)) t(x)", start, stop, step), + format("SELECT BIGINT '%s', BIGINT '%s'", start, stop)); + + step = -5; + stop = start + (sequenceLength - 1) * step; + assertQuery(format("SELECT max(x), min(x) " + + "FROM TABLE(sequence(" + + " start => %s," + + " stop => %s," + + " step => %s)) t(x)", start, stop, step), + format("SELECT BIGINT '%s', BIGINT '%s'", start, stop)); + } + + @Test + public void testEdgeValues() + { + long start = Long.MIN_VALUE + 15; + long stop = Long.MIN_VALUE + 3; + long step = -10; + assertQuery(format("SELECT * " + + "FROM TABLE(sequence( " + + " start => %s," + + " stop => %s," + + " step => %s))", start, stop, step), + format("VALUES (%s), (%s)", start, start + step)); + + start = Long.MIN_VALUE + 1 - (DEFAULT_SPLIT_SIZE - 1) * step; + stop = Long.MIN_VALUE + 1; + assertQuery(format("SELECT max(x), min(x) " + + "FROM TABLE(sequence( " + + " start => %s," + + " stop => %s," + + " step => %s)) t(x)", start, stop, step), + format("SELECT %s, %s", start, Long.MIN_VALUE + 1)); + + start = Long.MAX_VALUE - 15; + stop = Long.MAX_VALUE - 3; + step = 10; + assertQuery(format("SELECT * " + + "FROM TABLE(sequence( " + + " start => %s," + + " stop => %s," + + " step => %s))", start, stop, step), + format("VALUES (%s), (%s)", start, start + step)); + + start = Long.MAX_VALUE - 1 - (DEFAULT_SPLIT_SIZE - 1) * step; + stop = Long.MAX_VALUE - 1; + assertQuery(format("SELECT min(x), max(x) " + + "FROM TABLE(sequence(" + + " start => %s," + + " stop => %s," + + " step => %s)) t(x)", start, stop, step), + format("SELECT %s, %s", start, Long.MAX_VALUE - 1)); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/TestTableFunctionInvocation.java b/presto-tests/src/main/java/com/facebook/presto/tests/TestTableFunctionInvocation.java index 067c750f09497..8f4bec9fe182a 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/TestTableFunctionInvocation.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/TestTableFunctionInvocation.java @@ -16,12 +16,17 @@ import com.facebook.presto.connector.tvf.TestTVFConnectorColumnHandle; import com.facebook.presto.connector.tvf.TestTVFConnectorFactory; import com.facebook.presto.connector.tvf.TestTVFConnectorPlugin; +import com.facebook.presto.connector.tvf.TestingTableFunctions; import com.facebook.presto.connector.tvf.TestingTableFunctions.SimpleTableFunction; import com.facebook.presto.connector.tvf.TestingTableFunctions.SimpleTableFunction.SimpleTableFunctionHandle; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.FixedSplitSource; import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.SchemaFunctionName; import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tpch.TpchPlugin; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -32,7 +37,10 @@ import java.util.stream.IntStream; import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.connector.tvf.TestTVFConnectorFactory.TestTVFConnector.TestTVFConnectorSplit.TEST_TVF_CONNECTOR_SPLIT; +import static com.facebook.presto.connector.tvf.TestingTableFunctions.ConstantFunction.getConstantFunctionSplitSource; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; import static com.google.common.collect.ImmutableMap.toImmutableMap; public class TestTableFunctionInvocation @@ -52,6 +60,20 @@ protected QueryRunner createQueryRunner() .build(); } + @Override + protected QueryRunner createExpectedQueryRunner() + throws Exception + { + DistributedQueryRunner result = DistributedQueryRunner.builder(testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .build()) + .build(); + result.installPlugin(new TpchPlugin()); + result.createCatalog("tpch", "tpch"); + return result; + } + @BeforeClass public void setUp() { @@ -63,16 +85,79 @@ public void setUp() .collect(toImmutableMap(column -> column, column -> new TestTVFConnectorColumnHandle(column, createUnboundedVarcharType()) {})); queryRunner.installPlugin(new TestTVFConnectorPlugin(TestTVFConnectorFactory.builder() - .withTableFunctions(ImmutableSet.of(new SimpleTableFunction())) + .withTableFunctions(ImmutableSet.of(new SimpleTableFunction(), + new TestingTableFunctions.IdentityFunction(), + new TestingTableFunctions.IdentityPassThroughFunction(), + new TestingTableFunctions.RepeatFunction(), + new TestingTableFunctions.EmptyOutputFunction(), + new TestingTableFunctions.EmptyOutputWithPassThroughFunction(), + new TestingTableFunctions.EmptySourceFunction(), + new TestingTableFunctions.TestInputsFunction(), + new TestingTableFunctions.PassThroughInputFunction(), + new TestingTableFunctions.TestInputFunction(), + new TestingTableFunctions.TestSingleInputRowSemanticsFunction(), + new TestingTableFunctions.ConstantFunction())) .withApplyTableFunction((session, handle) -> { if (handle instanceof SimpleTableFunctionHandle) { SimpleTableFunctionHandle functionHandle = (SimpleTableFunctionHandle) handle; return Optional.of(new TableFunctionApplicationResult<>(functionHandle.getTableHandle(), functionHandle.getTableHandle().getColumns().orElseThrow(() -> new IllegalStateException("Columns are missing")))); } - throw new IllegalStateException("Unsupported table function handle: " + handle.getClass().getSimpleName()); - }).withGetColumnHandles(getColumnHandles) + return Optional.empty(); + }) + .withGetColumnHandles(getColumnHandles) + .withTableFunctionProcessorProvider( + connectorTableFunctionHandle -> { + if (connectorTableFunctionHandle instanceof TestingTableFunctions.TestingTableFunctionHandle) { + switch (((TestingTableFunctions.TestingTableFunctionHandle) connectorTableFunctionHandle).getSchemaFunctionName().getFunctionName()) { + case "identity_function": + return new TestingTableFunctions.IdentityFunction.IdentityFunctionProcessorProvider(); + case "identity_pass_through_function": + return new TestingTableFunctions.IdentityPassThroughFunction.IdentityPassThroughFunctionProcessorProvider(); + case "empty_output": + return new TestingTableFunctions.EmptyOutputFunction.EmptyOutputProcessorProvider(); + case "empty_output_with_pass_through": + return new TestingTableFunctions.EmptyOutputWithPassThroughFunction.EmptyOutputWithPassThroughProcessorProvider(); + case "empty_source": + return new TestingTableFunctions.EmptySourceFunction.EmptySourceFunctionProcessorProvider(); + case "test_inputs_function": + return new TestingTableFunctions.TestInputsFunction.TestInputsFunctionProcessorProvider(); + case "pass_through": + return new TestingTableFunctions.PassThroughInputFunction.PassThroughInputProcessorProvider(); + case "test_input": + return new TestingTableFunctions.TestInputFunction.TestInputProcessorProvider(); + case "test_single_input_function": + return new TestingTableFunctions.TestSingleInputRowSemanticsFunction.TestSingleInputFunctionProcessorProvider(); + default: + throw new IllegalArgumentException("unexpected table function: " + ((TestingTableFunctions.TestingTableFunctionHandle) connectorTableFunctionHandle).getSchemaFunctionName()); + } + } + else if (connectorTableFunctionHandle instanceof TestingTableFunctions.RepeatFunction.RepeatFunctionHandle) { + return new TestingTableFunctions.RepeatFunction.RepeatFunctionProcessorProvider(); + } + else if (connectorTableFunctionHandle instanceof TestingTableFunctions.ConstantFunction.ConstantFunctionHandle) { + return new TestingTableFunctions.ConstantFunction.ConstantFunctionProcessorProvider(); + } + return null; + }) + .withTableFunctionResolver(TestingTableFunctions.RepeatFunction.RepeatFunctionHandle.class) + .withTableFunctionResolver(TestingTableFunctions.TestingTableFunctionHandle.class) + .withTableFunctionResolver(TestingTableFunctions.ConstantFunction.ConstantFunctionHandle.class) + .withTableFunctionSplitResolver(TestingTableFunctions.ConstantFunction.ConstantFunctionSplit.class) + .withTableFunctionSplitSource( + connectorTableFunctionHandle -> { + if (connectorTableFunctionHandle instanceof TestingTableFunctions.ConstantFunction.ConstantFunctionHandle) { + return getConstantFunctionSplitSource((TestingTableFunctions.ConstantFunction.ConstantFunctionHandle) connectorTableFunctionHandle); + } + else if (connectorTableFunctionHandle instanceof TestingTableFunctions.TestingTableFunctionHandle && ((TestingTableFunctions.TestingTableFunctionHandle) connectorTableFunctionHandle).getSchemaFunctionName().equals(new SchemaFunctionName("system", "empty_source"))) { + return new FixedSplitSource(ImmutableList.of(TEST_TVF_CONNECTOR_SPLIT)); + } + return null; + }) .build())); queryRunner.createCatalog(TESTING_CATALOG, "testTVF"); + + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpch", "tpch"); } @Test @@ -91,4 +176,465 @@ public void testNoArgumentsPassed() assertQuery("SELECT col FROM TABLE(system.simple_table_function())", "SELECT true WHERE false"); } + + @Test + public void testIdentityFunction() + { + assertQuery("SELECT b, a FROM TABLE(system.identity_function(input => TABLE(VALUES (1, 2), (3, 4), (5, 6)) T(a, b)))", + "VALUES (2, 1), (4, 3), (6, 5)"); + + assertQuery("SELECT b, a FROM TABLE(system.identity_pass_through_function(input => TABLE(VALUES (1, 2), (3, 4), (5, 6)) T(a, b)))", + "VALUES (2, 1), (4, 3), (6, 5)"); + + // null partitioning value + assertQuery("SELECT i.b, a FROM TABLE(system.identity_function(input => TABLE(VALUES ('x', 1), ('y', 2), ('z', null)) T(a, b) PARTITION BY b)) i", + "VALUES (1, 'x'), (2, 'y'), (null, 'z')"); + + assertQuery("SELECT b, a FROM TABLE(system.identity_pass_through_function(input => TABLE(VALUES ('x', 1), ('y', 2), ('z', null)) T(a, b) PARTITION BY b))", + "VALUES (1, 'x'), (2, 'y'), (null, 'z')"); + + // the identity_function copies all input columns and outputs them as proper columns. + // the table tpch.tiny.orders has a hidden column row_number, which is not exposed to the function. + assertQuery("SELECT * FROM TABLE(system.identity_function(input => TABLE(tpch.tiny.region)))", + "SELECT * FROM tpch.tiny.region"); + + // the identity_pass_through_function passes all input columns on output using the pass-through mechanism (as opposed to producing proper columns). + // the table tpch.tiny.orders has a hidden column row_number, which is exposed to the pass-through mechanism. + // the passed-through column row_number preserves its hidden property. + assertQuery("SELECT row_number, * FROM TABLE(system.identity_pass_through_function(input => TABLE(tpch.tiny.orders)))", + "SELECT row_number, * FROM tpch.tiny.orders"); + } + + @Test + public void testRepeatFunction() + { + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES (1, 2), (3, 4), (5, 6))))", + "VALUES (1, 2), (1, 2), (3, 4), (3, 4), (5, 6), (5, 6)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES ('a', true), ('b', false)), 4))", + "VALUES ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES ('a', true), ('b', false)) t(x, y) PARTITION BY x,4))", + "VALUES ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES ('a', true), ('b', false)) t(x, y) ORDER BY y, 4))", + "VALUES ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(VALUES ('a', true), ('b', false)) t(x, y) PARTITION BY x ORDER BY y, 4))", + "VALUES ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false), ('a', true), ('b', false)"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(tpch.tiny.part), 3))", + "SELECT * FROM tpch.tiny.part UNION ALL TABLE tpch.tiny.part UNION ALL TABLE tpch.tiny.part"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(tpch.tiny.part) PARTITION BY type, 3))", + "SELECT * FROM tpch.tiny.part UNION ALL TABLE tpch.tiny.part UNION ALL TABLE tpch.tiny.part"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(tpch.tiny.part) ORDER BY size, 3))", + "SELECT * FROM tpch.tiny.part UNION ALL TABLE tpch.tiny.part UNION ALL TABLE tpch.tiny.part"); + + assertQuery("SELECT * FROM TABLE(system.repeat(TABLE(tpch.tiny.part) PARTITION BY type ORDER BY size, 3))", + "SELECT * FROM tpch.tiny.part UNION ALL TABLE tpch.tiny.part UNION ALL TABLE tpch.tiny.part"); + } + + @Test + public void testFunctionsReturningEmptyPages() + { + // the functions empty_output and empty_output_with_pass_through return an empty Page for each processed input Page. the argument has KEEP WHEN EMPTY property + + // non-empty input, no pass-trough columns + + assertQuery("SELECT * FROM TABLE(system.empty_output(TABLE(tpch.tiny.orders)))", + "SELECT true WHERE false"); + + // non-empty input, pass-through partitioning column + assertQuery("SELECT * FROM TABLE(system.empty_output(TABLE(tpch.tiny.orders) PARTITION BY orderstatus))", + "SELECT true, 'X' WHERE false"); + + // non-empty input, argument has pass-trough columns + assertQuery("SELECT * FROM TABLE(system.empty_output_with_pass_through(TABLE(tpch.tiny.orders)))", + "SELECT true, * FROM tpch.tiny.orders WHERE false"); + + // non-empty input, argument has pass-trough columns, partitioning column present + assertQuery("SELECT * FROM TABLE(system.empty_output_with_pass_through(TABLE(tpch.tiny.orders) PARTITION BY orderstatus))", + "SELECT true, * FROM tpch.tiny.orders WHERE false"); + + // empty input, no pass-trough columns + assertQuery("SELECT * FROM TABLE(system.empty_output(TABLE(SELECT * FROM tpch.tiny.orders WHERE false)))", + "SELECT true WHERE false"); + + // empty input, pass-through partitioning column + assertQuery("SELECT * FROM TABLE(system.empty_output(TABLE(SELECT * FROM tpch.tiny.orders WHERE false) PARTITION BY orderstatus))", + "SELECT true, 'X' WHERE false"); + + // empty input, argument has pass-trough columns + assertQuery("SELECT * FROM TABLE(system.empty_output_with_pass_through(TABLE(SELECT * FROM tpch.tiny.orders WHERE false)))", + "SELECT true, * FROM tpch.tiny.orders WHERE false"); + + // empty input, argument has pass-trough columns, partitioning column present + assertQuery("SELECT * FROM TABLE(system.empty_output_with_pass_through(TABLE(SELECT * FROM tpch.tiny.orders WHERE false) PARTITION BY orderstatus)) ", + "SELECT true, * FROM tpch.tiny.orders WHERE false"); + + // function empty_source returns an empty Page for each Split it processes + assertQuery("SELECT * FROM TABLE(system.empty_source())", + "SELECT true WHERE false"); + } + + @Test + public void testInputPartitioning() + { + // table function test_inputs_function has four table arguments. input_1 has row semantics. input_2, input_3 and input_4 have set semantics. + // the function outputs one row per each tuple of partition it processes. The row includes a true value, and partitioning values. + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 4, 5, 4, 5, 4) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 6, 7, 6) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 8, 9)))", + "VALUES (true, 4, 6), (true, 4, 7), (true, 5, 6), (true, 5, 7)"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 4, 5, 4, 5, 4) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 6, 7, 6) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 8, 9) t4(x4) PARTITION BY x4))", + "VALUES (true, 4, 6, 8), (true, 4, 6, 9), (true, 4, 7, 8), (true, 4, 7, 9), (true, 5, 6, 8), (true, 5, 6, 9), (true, 5, 7, 8), (true, 5, 7, 9)"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 4, 5, 4, 5, 4) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 6, 7, 6) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 8, 8) t4(x4) PARTITION BY x4))", + "VALUES (true, 4, 6, 8), (true, 4, 7, 8), (true, 5, 6, 8), (true, 5, 7, 8)"); + + // null partitioning values + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, null)," + + "input_2 => TABLE(VALUES 2, null, 2, null) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 3, null, 3, null) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES null, null) t4(x4) PARTITION BY x4))", + "VALUES (true, 2, 3, null), (true, 2, null, null), (true, null, 3, null), (true, null, null, null)"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 4, 5, 4, 5, 4)," + + "input_3 => TABLE(VALUES 6, 7, 6)," + + "input_4 => TABLE(VALUES 8, 9)))", + "VALUES true"); + + assertQuery("SELECT DISTINCT regionkey, nationkey FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(tpch.tiny.nation)," + + "input_2 => TABLE(tpch.tiny.nation) PARTITION BY regionkey ORDER BY name," + + "input_3 => TABLE(tpch.tiny.customer) PARTITION BY nationkey," + + "input_4 => TABLE(tpch.tiny.customer)))", + "SELECT DISTINCT n.regionkey, c.nationkey FROM tpch.tiny.nation n, tpch.tiny.customer c"); + } + + @Test + public void testEmptyPartitions() + { + // input_1 has row semantics, so it is prune when empty. input_2, input_3 and input_4 have set semantics, and are keep when empty by default + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false)," + + "input_3 => TABLE(SELECT 3 WHERE false)," + + "input_4 => TABLE(SELECT 4 WHERE false)))", + "VALUES true"); + + assertQueryReturnsEmptyResult("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(SELECT 1 WHERE false)," + + "input_2 => TABLE(VALUES 2)," + + "input_3 => TABLE(VALUES 3)," + + "input_4 => TABLE(VALUES 4)))"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(SELECT 3 WHERE false) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(SELECT 4 WHERE false) t4(x4) PARTITION BY x4))", + "VALUES (true, CAST(null AS integer), CAST(null AS integer), CAST(null AS integer))"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 3, 4, 4) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 4, 4, 4, 5, 5, 5, 5) t4(x4) PARTITION BY x4))", + "VALUES (true, CAST(null AS integer), 3, 4), (true, null, 4, 4), (true, null, 4, 5), (true, null, 3, 5)"); + + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(SELECT 3 WHERE false) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 4, 5) t4(x4) PARTITION BY x4))", + "VALUES (true, CAST(null AS integer), CAST(null AS integer), 4), (true, null, null, 5)"); + + assertQueryReturnsEmptyResult("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(SELECT 2 WHERE false) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY," + + "input_3 => TABLE(SELECT 3 WHERE false) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 4, 5) t4(x4) PARTITION BY x4))"); + } + + @Test + public void testCopartitioning() + { + // all tanbles are by default KEEP WHEN EMPTY. If there is no matching partition, it is null-completed + assertQuery("SELECT * FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, 2, 2) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 4, 5) t3(x3)," + + "input_4 => TABLE(VALUES 2, 2, 2, 3) t4(x4) PARTITION BY x4 " + + "COPARTITION (t2, t4)))", + "VALUES (true, 1, null), (true, 2, 2), (true, null, 3)"); + + // partition `3` from input_4 is pruned because there is no matching partition in input_2 + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, 2, 2) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY," + + "input_3 => TABLE(VALUES 4, 5) t3(x3)," + + "input_4 => TABLE(VALUES 2, 2, 2, 3) t4(x4) PARTITION BY x4 " + + "COPARTITION (t2, t4)))", + "VALUES (true, 1, null), (true, 2, 2)"); + + // partition `1` from input_2 is pruned because there is no matching partition in input_4 + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, 2, 2) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 4, 5) t3(x3)," + + "input_4 => TABLE(VALUES 2, 2, 2, 3) t4(x4) PARTITION BY x4 PRUNE WHEN EMPTY " + + "COPARTITION (t2, t4)))", + "VALUES (true, 2, 2), (true, null, 3)"); + + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, 2, 2) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY," + + "input_3 => TABLE(VALUES 4, 5) t3(x3)," + + "input_4 => TABLE(VALUES 2, 2, 2, 3) t4(x4) PARTITION BY x4 PRUNE WHEN EMPTY " + + "COPARTITION (t2, t4)))", + "VALUES (true, 2, 2)"); + + // null partitioning values + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, null, null, 2, 2) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 4, 5) t3(x3)," + + "input_4 => TABLE(VALUES null, 2, 2, 2, 3) t4(x4) PARTITION BY x4 " + + "COPARTITION (t2, t4)))", + "VALUES (true, 1, null), (true, 2, 2), (true, null, null), (true, null, 3)"); + + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, null, null, 2, 2) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY," + + "input_3 => TABLE(VALUES 4, 5) t3(x3)," + + "input_4 => TABLE(VALUES null, 2, 2, 2, 3) t4(x4) PARTITION BY x4 PRUNE WHEN EMPTY " + + "COPARTITION (t2, t4)))", + "VALUES (true, 2, 2), (true, null, null)"); + + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, null, null) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 2, 2, null) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 2, 3, 3) t4(x4) PARTITION BY x4 " + + "COPARTITION (t2, t4, t3)))", + "VALUES (true, 1, null, null), (true, null, null, null), (true, null, 2, 2), (true, null, null, 3)"); + + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, null, null) t2(x2) PARTITION BY x2," + + "input_3 => TABLE(VALUES 2, 2, null) t3(x3) PARTITION BY x3 PRUNE WHEN EMPTY," + + "input_4 => TABLE(VALUES 2, 3, 3) t4(x4) PARTITION BY x4 " + + "COPARTITION (t2, t4, t3)))", + "VALUES (true, CAST(null AS integer), null, null), (true, null, 2, 2)"); + + assertQuery("SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, null, null) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY," + + "input_3 => TABLE(VALUES 2, 2, null) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 2, 3, 3) t4(x4) PARTITION BY x4 " + + "COPARTITION (t2, t4, t3)))", + "VALUES (true, 1, CAST(null AS integer), CAST(null AS integer)), (true, null, null, null)"); + + assertQueryReturnsEmptyResult( + "SELECT *" + + "FROM TABLE(system.test_inputs_function(" + + "input_1 => TABLE(VALUES 1, 2, 3)," + + "input_2 => TABLE(VALUES 1, 1, null, null) t2(x2) PARTITION BY x2 PRUNE WHEN EMPTY," + + "input_3 => TABLE(VALUES 2, 2, null) t3(x3) PARTITION BY x3," + + "input_4 => TABLE(VALUES 2, 3, 3) t4(x4) PARTITION BY x4 PRUNE WHEN EMPTY " + + "COPARTITION (t2, t4, t3)))"); + } + + @Test + public void testPassThroughWithEmptyPartitions() + { + assertQuery("SELECT * FROM TABLE(system.pass_through(" + + "TABLE(VALUES (1, 'a'), (2, 'b')) t1(a1, b1) PARTITION BY a1," + + "TABLE(VALUES (2, 'x'), (3, 'y')) t2(a2, b2) PARTITION BY a2 " + + "COPARTITION (t1, t2)))", + "VALUES (true, false, 1, 'a', null, null), (true, true, 2, 'b', 2, 'x'), (false, true, null, null, 3, 'y')"); + + assertQuery("SELECT * FROM TABLE(system.pass_through(" + + "TABLE(VALUES (1, 'a'), (2, 'b')) t1(a1, b1) PARTITION BY a1," + + "TABLE(SELECT 2, 'x' WHERE false) t2(a2, b2) PARTITION BY a2 " + + "COPARTITION (t1, t2)))", + "VALUES (true, false, 1, 'a', CAST(null AS integer), CAST(null AS VARCHAR(1))), (true, false, 2, 'b', null, null)"); + + assertQuery("SELECT * FROM TABLE(system.pass_through(" + + "TABLE(VALUES (1, 'a'), (2, 'b')) t1(a1, b1) PARTITION BY a1," + + "TABLE(SELECT 2, 'x' WHERE false) t2(a2, b2) PARTITION BY a2))", + "VALUES (true, false, 1, 'a', CAST(null AS integer), CAST(null AS VARCHAR(1))), (true, false, 2, 'b', null, null)"); + } + + @Test + public void testPassThroughWithEmptyInput() + { + assertQuery("SELECT * FROM TABLE(system.pass_through(TABLE(SELECT 1, 'x' WHERE false) t1(a1, b1) PARTITION BY a1, TABLE(SELECT 2, 'y' WHERE false) t2(a2, b2) PARTITION BY a2 COPARTITION (t1, t2)))", + "VALUES (false, false, CAST(null AS integer), CAST(null AS VARCHAR(1)), CAST(null AS integer), CAST(null AS VARCHAR(1)))"); + + assertQuery("SELECT * FROM TABLE(system.pass_through(TABLE(SELECT 1, 'x' WHERE false) t1(a1, b1) PARTITION BY a1, TABLE(SELECT 2, 'y' WHERE false) t2(a2, b2) PARTITION BY a2))", + "VALUES (false, false, CAST(null AS integer), CAST(null AS VARCHAR(1)), CAST(null AS integer), CAST(null AS VARCHAR(1)))"); + } + + @Test + public void testInput() + { + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(VALUES 1)))", "VALUES true"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(VALUES 1, 2, 3) t(a) PARTITION BY a))", + "VALUES true, true, true"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(SELECT 1 WHERE false)))", "VALUES false"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(SELECT 1 WHERE false) t(a) PARTITION BY a))", + "VALUES false"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(SELECT * FROM tpch.tiny.orders WHERE false)))", "VALUES false"); + + assertQuery("SELECT got_input FROM TABLE(system.test_input(TABLE(SELECT * FROM tpch.tiny.orders WHERE false) PARTITION BY orderstatus ORDER BY orderkey))", "VALUES false"); + } + + @Test + public void testSingleSourceWithRowSemantics() + { + assertQuery("SELECT * FROM TABLE(system.test_single_input_function(TABLE(VALUES (true), (false), (true))))", "VALUES true"); + } + + @Test + public void testConstantFunction() + { + assertQuery("SELECT * FROM TABLE(system.constant(5))", "VALUES 5"); + + assertQuery("SELECT * FROM TABLE(system.constant(2, 10))", "VALUES (2), (2), (2), (2), (2), (2), (2), (2), (2), (2)"); + + assertQuery("SELECT * FROM TABLE(system.constant(null, 3))", "VALUES (CAST(null AS integer)), (null), (null)"); + + // value as constant expression + assertQuery("SELECT * FROM TABLE(system.constant(5 * 4, 3))", "VALUES (20), (20), (20)"); + + assertQueryFails("SELECT * FROM TABLE(system.constant(2147483648, 3))", "line 1:37: Cannot cast type bigint to integer"); + + assertQuery("SELECT count(*), count(DISTINCT constant_column), min(constant_column) FROM TABLE(system.constant(2, 1000000))", "VALUES (BIGINT '1000000', BIGINT '1', 2)"); + } + + @Test + public void testPruneAllColumns() + { + // function identity_pass_through_function has no proper outputs. It outputs input columns using the pass-through mechanism. + // in this case, no pass-through columns are referenced, so they are all pruned. The function effectively produces no columns. + assertQuery("SELECT 'a' FROM TABLE(system.identity_pass_through_function(input => TABLE(VALUES 1, 2, 3)))", + "VALUES 'a', 'a', 'a'"); + + // all pass-through columns are pruned. Also, the input is empty, and it has KEEP WHEN EMPTY property, so the function is executed on empty partition. + assertQuery("SELECT 'a' FROM TABLE(system.identity_pass_through_function(input => TABLE(SELECT 1 WHERE false)))", + "SELECT 'a' WHERE false"); + + // all pass-through columns are pruned. Also, the input is empty, and it has PRUNE WHEN EMPTY property, so the function is pruned out. + assertQuery("SELECT 'a' FROM TABLE(system.identity_pass_through_function(input => TABLE(SELECT 1 WHERE false) PRUNE WHEN EMPTY))", + "SELECT 'a' WHERE false"); + } + + @Test + public void testPrunePassThroughColumns() + { + // function pass_through has 2 proper columns, and it outputs all columns from both inputs using the pass-through mechanism. + // all columns are referenced + assertQuery("SELECT p1, p2, x1, x2, y1, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(VALUES (1, 'a'), (2, 'b'), (3, 'c')) t1(x1, x2)," + + " TABLE(VALUES (4, 'd'), (5, 'e')) t2(y1, y2))) t(p1, p2)", + "VALUES (true, true, 3, 'c', 5, 'e')"); + + // all pass-through columns are referenced. Proper columns are not referenced, but they are not pruned. + assertQuery("SELECT x1, x2, y1, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(VALUES (1, 'a'), (2, 'b'), (3, 'c')) t1(x1, x2)," + + " TABLE(VALUES (4, 'd'), (5, 'e')) t2(y1, y2))) t(p1, p2)", + "VALUES (3, 'c', 5, 'e')"); + + // some pass-through columns are referenced. Unreferenced pass-through columns are pruned. + assertQuery("SELECT x2, y2 " + + "FROM TABLE(system.pass_through(" + + " TABLE(VALUES (1, 'a'), (2, 'b'), (3, 'c')) t1(x1, x2)," + + " TABLE(VALUES (4, 'd'), (5, 'e')) t2(y1, y2))) t(p1, p2)", + "VALUES ('c', 'e')"); + + assertQuery("SELECT y1, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(VALUES (1, 'a'), (2, 'b'), (3, 'c')) t1(x1, x2)," + + " TABLE(VALUES (4, 'd'), (5, 'e')) t2(y1, y2))) t(p1, p2)", + "VALUES (5, 'e')"); + + // no pass-through columns are referenced. Unreferenced pass-through columns are pruned. + assertQuery("SELECT 'x' " + + "FROM TABLE(system.pass_through( " + + " TABLE(VALUES (1, 'a'), (2, 'b'), (3, 'c')) t1(x1, x2)," + + " TABLE(VALUES (4, 'd'), (5, 'e')) t2(y1, y2))) t(p1, p2)", + "VALUES ('x')"); + } + + @Test + public void testPrunePassThroughColumnsWithEmptyInput() + { + // function pass_through has 2 proper columns, and it outputs all columns from both inputs using the pass-through mechanism. + // all columns are referenced + assertQuery("SELECT p1, p2, x1, x2, y1, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(SELECT 1, 'a' WHERE FALSE) t1(x1, x2)," + + " TABLE(SELECT 4, 'd' WHERE FALSE) t2(y1, y2))) t(p1, p2)", + "VALUES (false, false, CAST(null AS integer), CAST(null AS varchar(1)), CAST(null AS integer), CAST(null AS varchar(1)))"); + + // all pass-through columns are referenced. Proper columns are not referenced, but they are not pruned. + assertQuery("SELECT x1, x2, y1, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(SELECT 1, 'a' WHERE FALSE) t1(x1, x2)," + + " TABLE(SELECT 4, 'd' WHERE FALSE) t2(y1, y2))) t(p1, p2) ", + "VALUES (CAST(null AS integer), CAST(null AS varchar(1)), CAST(null AS integer), CAST(null AS varchar(1)))"); + + // some pass-through columns are referenced. Unreferenced pass-through columns are pruned. + assertQuery("SELECT x2, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(SELECT 1, 'a' WHERE FALSE) t1(x1, x2)," + + " TABLE(SELECT 4, 'd' WHERE FALSE) t2(y1, y2))) t(p1, p2)", + "VALUES (CAST(null AS varchar(1)), CAST(null AS varchar(1)))"); + + assertQuery("SELECT y1, y2 " + + "FROM TABLE(system.pass_through( " + + " TABLE(SELECT 1, 'a' WHERE FALSE) t1(x1, x2)," + + " TABLE(SELECT 4, 'd' WHERE FALSE) t2(y1, y2))) t(p1, p2)", + "VALUES (CAST(null AS integer), CAST(null AS varchar(1)))"); + + // no pass-through columns are referenced. Unreferenced pass-through columns are pruned. + assertQuery("SELECT 'x' " + + "FROM TABLE(system.pass_through(" + + " TABLE(SELECT 1, 'a' WHERE FALSE) t1(x1, x2)," + + " TABLE(SELECT 4, 'd' WHERE FALSE) t2(y1, y2))) t(p1, p2)", + "VALUES ('x')"); + } }