Skip to content

[FLINK-37790][4/N] introduce model provider related interfaces #26577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.factories;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.table.catalog.Catalog;
import org.apache.flink.table.catalog.CatalogModel;
import org.apache.flink.table.catalog.ObjectIdentifier;
import org.apache.flink.table.catalog.ResolvedCatalogModel;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.table.ml.ModelProvider;

/**
* Creates a {@link ModelProvider} instance from a {@link CatalogModel} and additional context
* information.
*
* <p>See {@link Factory} for more information about the general design of a factory.
*/
@PublicEvolving
public interface ModelProviderFactory {

/**
* Creates a {@link ModelProvider} for the model described by the given context.
*
* @param context catalog and session information describing the model to be accessed
* @return the model provider created for that model
*/
ModelProvider createModelProvider(Context context);

/** Provides catalog and session information describing the model to be accessed. */
@PublicEvolving
interface Context {
/**
* Returns the identifier of the model in the {@link Catalog}.
*
* <p>This identifier describes the relationship between the model instance and the
* associated {@link Catalog} (if any).
*/
ObjectIdentifier getObjectIdentifier();

/**
* Returns the resolved model information received from the {@link Catalog} or persisted
* plan.
*
* <p>The {@link ResolvedCatalogModel} forwards the metadata from the catalog but offers a
* validated {@link ResolvedSchema}. The original metadata object is available via {@link
* ResolvedCatalogModel#getOrigin()}.
*
* <p>In most cases, a factory is interested in the following characteristics:
*
* <pre>{@code
* // get the physical input and output data types to initialize the model provider
* context.getCatalogModel().getResolvedInputSchema().toPhysicalRowDataType()
* context.getCatalogModel().getResolvedOutputSchema().toPhysicalRowDataType()
*
* // get configuration options
* context.getCatalogModel().getOptions()
* }</pre>
*
* <p>During a plan restore, usually the model information persisted in the plan is used to
* reconstruct the catalog model.
*/
ResolvedCatalogModel getCatalogModel();

/** Gives read-only access to the configuration of the current session. */
ReadableConfig getConfiguration();

/**
* Returns the class loader of the current session.
*
* <p>The class loader is in particular useful for discovering further (nested) factories.
*/
ClassLoader getClassLoader();

/** Returns whether the model is temporary. */
boolean isTemporary();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.functions;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;

import java.util.Collection;
import java.util.concurrent.CompletableFuture;

/**
 * A wrapper class of {@link AsyncTableFunction} for asynchronous model inference.
 *
 * <p>The output type of this table function is fixed as {@link RowData}.
 *
 * <p>Implementations provide {@link #asyncPredict(RowData)}. {@link #eval} packs the call
 * arguments into a row, delegates to {@link #asyncPredict(RowData)} and forwards the outcome to
 * the result future handed in by the caller.
 */
@PublicEvolving
public abstract class AsyncPredictFunction extends AsyncTableFunction<RowData> {

    /**
     * Asynchronously predict result based on input row.
     *
     * @param inputRow - A {@link RowData} that wraps input for predict function.
     * @return A future completing with a collection of all predicted results.
     */
    public abstract CompletableFuture<Collection<RowData>> asyncPredict(RowData inputRow);

    /**
     * Invokes {@link #asyncPredict} and chains futures.
     *
     * <p>Every failure is reported through {@code future} as a {@link TableException} that
     * mentions the input row: an exceptionally completed prediction future, an exception thrown
     * synchronously by {@link #asyncPredict}, and a {@code null} prediction future.
     *
     * @param future the result future to complete with the predicted rows or with the failure
     * @param args the call arguments, wrapped into a single {@link GenericRowData} input row
     */
    public void eval(CompletableFuture<Collection<RowData>> future, Object... args) {
        GenericRowData argsData = GenericRowData.of(args);

        CompletableFuture<Collection<RowData>> prediction;
        try {
            prediction = asyncPredict(argsData);
        } catch (Exception e) {
            // A synchronous failure would otherwise escape without the input row context.
            future.completeExceptionally(predictionFailure(argsData, e));
            return;
        }

        if (prediction == null) {
            // Without this guard the chaining below fails with a context-free NPE.
            future.completeExceptionally(
                    predictionFailure(
                            argsData,
                            new NullPointerException("asyncPredict returned a null future.")));
            return;
        }

        prediction.whenComplete(
                (result, exception) -> {
                    if (exception != null) {
                        future.completeExceptionally(predictionFailure(argsData, exception));
                        return;
                    }
                    future.complete(result);
                });
    }

    /** Wraps a prediction failure into a {@link TableException} that names the input row. */
    private static TableException predictionFailure(RowData inputRow, Throwable cause) {
        return new TableException(
                String.format(
                        "Failed to execute asynchronously prediction with input row %s.",
                        inputRow),
                cause);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.functions;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.util.FlinkRuntimeException;

import java.util.Collection;

/**
 * Base class for synchronous model inference, built on top of {@link TableFunction}.
 *
 * <p>The output type of this table function is fixed as {@link RowData}.
 */
@PublicEvolving
public abstract class PredictFunction extends TableFunction<RowData> {

    /**
     * Synchronously predict result based on input row.
     *
     * @param inputRow - A {@link RowData} that wraps input for predict function.
     * @return A collection of predicted results.
     */
    public abstract Collection<RowData> predict(RowData inputRow);

    /**
     * Wraps the call arguments into a row, runs {@link #predict} on it and emits every predicted
     * row. A {@code null} result emits nothing. Any exception raised while predicting or emitting
     * is rethrown as a {@link FlinkRuntimeException} that names the input row.
     */
    public final void eval(Object... args) {
        final GenericRowData inputRow = GenericRowData.of(args);
        try {
            final Collection<RowData> predictions = predict(inputRow);
            if (predictions != null) {
                for (RowData prediction : predictions) {
                    collect(prediction);
                }
            }
        } catch (Exception cause) {
            final String message =
                    String.format("Failed to execute prediction with input row %s.", inputRow);
            throw new FlinkRuntimeException(message, cause);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.ml;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.table.functions.AsyncPredictFunction;

/**
* A {@link ModelProvider} that creates the function used for asynchronous prediction.
*/
@PublicEvolving
public interface AsyncPredictRuntimeProvider extends ModelProvider {

/**
* Creates an {@link AsyncPredictFunction} instance.
*
* @param context context for creating the runtime provider, see {@link ModelProvider.Context}
* @return the function performing asynchronous prediction
*/
AsyncPredictFunction createAsyncPredictFunction(Context context);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.ml;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.table.catalog.ResolvedCatalogModel;

/**
* A model provider defines how to handle models from a particular provider, for example how to
* run model inference for the models of that provider.
*
* <p>Types of model providers include, but are not limited to:
*
* <ul>
* <li>Public vendors (e.g. OpenAI, Anthropic, DeepSeek, etc.).
* <li>Flink native (models trained natively by Flink).
* </ul>
*/
@PublicEvolving
public interface ModelProvider {

/**
* Creates a copy of this instance during planning. The copy should be a deep copy of all
* mutable members.
*
* @return a copy of this model provider
*/
ModelProvider copy();

/** Context for creating runtime providers. */
@PublicEvolving
interface Context {

/**
* Returns the resolved catalog model.
*
* @return the resolved catalog model
*/
ResolvedCatalogModel getCatalogModel();

/**
* Returns the runtime config provided to the provider. The config can be used by the planner
* or by the model provider at runtime. For example, async options can be used by the planner
* to choose async inference. Other config such as HTTP timeout or retry can be used to
* configure the model provider's runtime HTTP client when calling external model providers
* such as OpenAI.
*
* @return read-only runtime configuration
*/
ReadableConfig runtimeConfig();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.ml;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.table.functions.PredictFunction;

/**
* A {@link ModelProvider} that creates the function used for synchronous prediction.
*/
@PublicEvolving
public interface PredictRuntimeProvider extends ModelProvider {

/**
* Creates a {@link PredictFunction} instance.
*
* @param context context for creating the runtime provider, see {@link ModelProvider.Context}
* @return the function performing synchronous prediction
*/
PredictFunction createPredictFunction(Context context);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
package org.apache.flink.table.module;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.table.catalog.Catalog;
import org.apache.flink.table.factories.DynamicTableSinkFactory;
import org.apache.flink.table.factories.DynamicTableSourceFactory;
import org.apache.flink.table.factories.ModelProviderFactory;
import org.apache.flink.table.functions.FunctionDefinition;

import java.util.Collections;
Expand Down Expand Up @@ -111,5 +113,27 @@ default Optional<DynamicTableSinkFactory> getTableSinkFactory() {
return Optional.empty();
}

/**
* Returns a {@link ModelProviderFactory} for creating model providers.
*
* <p>A factory is determined with the following precedence rule:
*
* <ul>
* <li>1. Factory provided by the corresponding catalog of a persisted model. See {@link
* Catalog#getFactory()}
* <li>2. Factory provided by a module.
* <li>3. Factory discovered using Java SPI.
* </ul>
*
* <p>This will be called on loaded modules in the order in which they have been loaded. The
* first factory returned will be used.
*
* <p>This method can be useful to disable Java SPI completely or influence how temporary model
* providers should be created without a corresponding catalog.
*
* @return the factory offered by this module; the default implementation returns {@link
* Optional#empty()}
*/
default Optional<ModelProviderFactory> getModelProviderFactory() {
return Optional.empty();
}

// user defined types, operators, rules, etc
}