From cc74ef89b8b99b18105659a8a7effbe1721a7e33 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 2 Jan 2025 11:23:46 +0800 Subject: [PATCH 1/4] Refactor export_lib and add export_onnx Add tf2onnx requirements --- .kokoro/github/ubuntu/gpu/build.sh | 1 - keras/api/_tf_keras/keras/export/__init__.py | 2 +- keras/api/_tf_keras/keras/layers/__init__.py | 2 +- keras/api/export/__init__.py | 2 +- keras/api/layers/__init__.py | 2 +- keras/src/backend/torch/export.py | 24 +- keras/src/export/__init__.py | 5 +- keras/src/export/export_utils.py | 105 +++++++++ keras/src/export/onnx.py | 162 +++++++++++++ keras/src/export/onnx_test.py | 216 +++++++++++++++++ .../export/{export_lib.py => saved_model.py} | 222 +----------------- ...export_lib_test.py => saved_model_test.py} | 209 +++-------------- keras/src/export/tfsm_layer.py | 139 +++++++++++ keras/src/export/tfsm_layer_test.py | 145 ++++++++++++ keras/src/layers/core/dense_test.py | 6 +- keras/src/layers/core/einsum_dense_test.py | 6 +- keras/src/layers/core/embedding_test.py | 4 +- keras/src/layers/layer.py | 16 +- keras/src/models/model.py | 47 +++- keras/src/models/model_test.py | 69 ++++-- keras/src/utils/module_utils.py | 1 + requirements-jax-cuda.txt | 1 + requirements-tensorflow-cuda.txt | 1 + requirements-torch-cuda.txt | 1 + requirements.txt | 1 + 25 files changed, 943 insertions(+), 446 deletions(-) create mode 100644 keras/src/export/export_utils.py create mode 100644 keras/src/export/onnx.py create mode 100644 keras/src/export/onnx_test.py rename keras/src/export/{export_lib.py => saved_model.py} (75%) rename keras/src/export/{export_lib_test.py => saved_model_test.py} (83%) create mode 100644 keras/src/export/tfsm_layer.py create mode 100644 keras/src/export/tfsm_layer_test.py diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index a70f28a062a0..9164cee023dd 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -72,7 +72,6 @@ then # Raise error if GPU is not detected. python3 -c 'import torch;assert torch.cuda.is_available()' - # TODO: keras/src/export/export_lib_test.py update LD_LIBRARY_PATH pytest keras --ignore keras/src/applications \ --cov=keras \ --cov-config=pyproject.toml diff --git a/keras/api/_tf_keras/keras/export/__init__.py b/keras/api/_tf_keras/keras/export/__init__.py index 68fa60293961..49f7a66972be 100644 --- a/keras/api/_tf_keras/keras/export/__init__.py +++ b/keras/api/_tf_keras/keras/export/__init__.py @@ -4,4 +4,4 @@ since your modifications would be overwritten. """ -from keras.src.export.export_lib import ExportArchive +from keras.src.export.saved_model import ExportArchive diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py index 82e8d0da9d15..4f13a5961303 100644 --- a/keras/api/_tf_keras/keras/layers/__init__.py +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -4,7 +4,7 @@ since your modifications would be overwritten. """ -from keras.src.export.export_lib import TFSMLayer +from keras.src.export.tfsm_layer import TFSMLayer from keras.src.layers import deserialize from keras.src.layers import serialize from keras.src.layers.activations.activation import Activation diff --git a/keras/api/export/__init__.py b/keras/api/export/__init__.py index 68fa60293961..49f7a66972be 100644 --- a/keras/api/export/__init__.py +++ b/keras/api/export/__init__.py @@ -4,4 +4,4 @@ since your modifications would be overwritten. 
""" -from keras.src.export.export_lib import ExportArchive +from keras.src.export.saved_model import ExportArchive diff --git a/keras/api/layers/__init__.py b/keras/api/layers/__init__.py index a70561253b08..a4aaf7c99174 100644 --- a/keras/api/layers/__init__.py +++ b/keras/api/layers/__init__.py @@ -4,7 +4,7 @@ since your modifications would be overwritten. """ -from keras.src.export.export_lib import TFSMLayer +from keras.src.export.tfsm_layer import TFSMLayer from keras.src.layers import deserialize from keras.src.layers import serialize from keras.src.layers.activations.activation import Activation diff --git a/keras/src/backend/torch/export.py b/keras/src/backend/torch/export.py index 6f05a8257251..7de5653e9fb5 100644 --- a/keras/src/backend/torch/export.py +++ b/keras/src/backend/torch/export.py @@ -3,9 +3,8 @@ import torch -from keras.src import backend -from keras.src import ops from keras.src import tree +from keras.src.export.export_utils import convert_spec_to_tensor from keras.src.utils.module_utils import tensorflow as tf from keras.src.utils.module_utils import torch_xla @@ -36,23 +35,10 @@ def track_and_add_endpoint(self, name, resource, input_signature, **kwargs): f"Received: resource={resource} (of type {type(resource)})" ) - def _check_input_signature(input_spec): - for s in tree.flatten(input_spec.shape): - if s is None: - raise ValueError( - "The shape in the `input_spec` must be fully " - f"specified. Received: input_spec={input_spec}" - ) - - def _to_torch_tensor(x, replace_none_number=1): - shape = backend.standardize_shape(x.shape) - shape = tuple( - s if s is not None else replace_none_number for s in shape - ) - return ops.ones(shape, x.dtype) - - tree.map_structure(_check_input_signature, input_signature) - sample_inputs = tree.map_structure(_to_torch_tensor, input_signature) + sample_inputs = tree.map_structure( + lambda x: convert_spec_to_tensor(x, replace_none_number=1), + input_signature, + ) sample_inputs = tuple(sample_inputs) # Ref: torch_xla.tf_saved_model_integration diff --git a/keras/src/export/__init__.py b/keras/src/export/__init__.py index d9de43f685a0..a51487812ea0 100644 --- a/keras/src/export/__init__.py +++ b/keras/src/export/__init__.py @@ -1 +1,4 @@ -from keras.src.export.export_lib import ExportArchive +from keras.src.export.onnx import export_onnx +from keras.src.export.saved_model import ExportArchive +from keras.src.export.saved_model import export_saved_model +from keras.src.export.tfsm_layer import TFSMLayer diff --git a/keras/src/export/export_utils.py b/keras/src/export/export_utils.py new file mode 100644 index 000000000000..bfb66180f4b9 --- /dev/null +++ b/keras/src/export/export_utils.py @@ -0,0 +1,105 @@ +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import tree +from keras.src.utils.module_utils import tensorflow as tf + + +def get_input_signature(model): + if not isinstance(model, models.Model): + raise TypeError( + "The model must be a `keras.Model`. " + f"Received: model={model} of the type {type(model)}" + ) + if not model.built: + raise ValueError( + "The model provided has not yet been built. It must be built " + "before export." 
+ ) + if isinstance(model, (models.Functional, models.Sequential)): + input_signature = tree.map_structure(make_input_spec, model.inputs) + if isinstance(input_signature, list) and len(input_signature) > 1: + input_signature = [input_signature] + else: + input_signature = _infer_input_signature_from_model(model) + if not input_signature or not model._called: + raise ValueError( + "The model provided has never called. " + "It must be called at least once before export." + ) + return input_signature + + +def _infer_input_signature_from_model(model): + shapes_dict = getattr(model, "_build_shapes_dict", None) + if not shapes_dict: + return None + + def _make_input_spec(structure): + # We need to turn wrapper structures like TrackingDict or _DictWrapper + # into plain Python structures because they don't work with jax2tf/JAX. + if isinstance(structure, dict): + return {k: _make_input_spec(v) for k, v in structure.items()} + elif isinstance(structure, tuple): + if all(isinstance(d, (int, type(None))) for d in structure): + return layers.InputSpec( + shape=(None,) + structure[1:], dtype=model.input_dtype + ) + return tuple(_make_input_spec(v) for v in structure) + elif isinstance(structure, list): + if all(isinstance(d, (int, type(None))) for d in structure): + return layers.InputSpec( + shape=[None] + structure[1:], dtype=model.input_dtype + ) + return [_make_input_spec(v) for v in structure] + else: + raise ValueError( + f"Unsupported type {type(structure)} for {structure}" + ) + + return [_make_input_spec(value) for value in shapes_dict.values()] + + +def make_input_spec(x): + if isinstance(x, layers.InputSpec): + if x.shape is None or x.dtype is None: + raise ValueError( + "The `shape` and `dtype` must be provided. " f"Received: x={x}" + ) + input_spec = x + elif isinstance(x, backend.KerasTensor): + shape = (None,) + backend.standardize_shape(x.shape)[1:] + dtype = backend.standardize_dtype(x.dtype) + input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=x.name) + elif backend.is_tensor(x): + shape = (None,) + backend.standardize_shape(x.shape)[1:] + dtype = backend.standardize_dtype(x.dtype) + input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=None) + else: + raise TypeError( + f"Unsupported x={x} of the type ({type(x)}). Supported types are: " + "`keras.InputSpec`, `keras.KerasTensor` and backend tensor." 
+ ) + return input_spec + + +def make_tf_tensor_spec(x): + if isinstance(x, tf.TensorSpec): + tensor_spec = x + else: + input_spec = make_input_spec(x) + tensor_spec = tf.TensorSpec( + input_spec.shape, dtype=input_spec.dtype, name=input_spec.name + ) + return tensor_spec + + +def convert_spec_to_tensor(spec, replace_none_number=None): + shape = backend.standardize_shape(spec.shape) + if replace_none_number is not None: + replace_none_number = int(replace_none_number) + shape = tuple( + s if s is not None else replace_none_number for s in shape + ) + return ops.ones(shape, spec.dtype) diff --git a/keras/src/export/onnx.py b/keras/src/export/onnx.py new file mode 100644 index 000000000000..acca68bdcc32 --- /dev/null +++ b/keras/src/export/onnx.py @@ -0,0 +1,162 @@ +import pathlib +import tempfile + +from keras.src import backend +from keras.src import tree +from keras.src.export.export_utils import convert_spec_to_tensor +from keras.src.export.export_utils import get_input_signature +from keras.src.export.saved_model import export_saved_model +from keras.src.utils.module_utils import tensorflow as tf + + +def export_onnx(model, filepath, verbose=True, input_signature=None, **kwargs): + """Export the model as a ONNX artifact for inference. + + This method lets you export a model to a lightweight ONNX artifact + that contains the model's forward pass only (its `call()` method) + and can be served via e.g. ONNX Runtime. + + The original code of the model (including any custom layers you may + have used) is *no longer* necessary to reload the artifact -- it is + entirely standalone. + + Args: + filepath: `str` or `pathlib.Path` object. The path to save the artifact. + verbose: `bool`. Whether to print a message during export. Defaults to + True`. + input_signature: Optional. Specifies the shape and dtype of the model + inputs. Can be a structure of `keras.InputSpec`, `tf.TensorSpec`, + `backend.KerasTensor`, or backend tensor. If not provided, it will + be automatically computed. Defaults to `None`. + **kwargs: Additional keyword arguments. + + **Note:** This feature is currently supported only with TensorFlow, JAX and + Torch backends. + + **Note:** The dtype policy must be "float32" for the model. You can further + optimize the ONNX artifact using the ONNX toolkit. Learn more here: + https://onnxruntime.ai/docs/performance/. + + **Note:** The dynamic shape feature is not yet supported with Torch + backend. As a result, you must fully define the shapes of the inputs using + `input_signature`. If `input_signature` is not provided, all instances of + `None` (such as the batch size) will be replaced with `1`. + + Example: + + ```python + # Export the model as a ONNX artifact + model.export("path/to/location", format="onnx") + + # Load the artifact in a different process/environment + ort_session = onnxruntime.InferenceSession("path/to/location") + ort_inputs = { + k.name: v for k, v in zip(ort_session.get_inputs(), input_data) + } + predictions = ort_session.run(None, ort_inputs) + ``` + """ + if input_signature is None: + input_signature = get_input_signature(model) + if not input_signature or not model._called: + raise ValueError( + "The model provided has never called. " + "It must be called at least once before export." 
+ ) + + if backend.backend() in ("tensorflow", "jax"): + working_dir = pathlib.Path(filepath).parent + with tempfile.TemporaryDirectory(dir=working_dir) as temp_dir: + if backend.backend() == "jax": + kwargs = _check_jax_kwargs(kwargs) + export_saved_model( + model, + temp_dir, + verbose, + input_signature, + **kwargs, + ) + saved_model_to_onnx(temp_dir, filepath, model.name) + + elif backend.backend() == "torch": + import torch + + sample_inputs = tree.map_structure( + lambda x: convert_spec_to_tensor(x, replace_none_number=1), + input_signature, + ) + sample_inputs = tuple(sample_inputs) + # TODO: Make dict model exportable. + if any(isinstance(x, dict) for x in sample_inputs): + raise ValueError( + "Currently, `export_onnx` in the torch backend doesn't support " + "dictionaries as inputs." + ) + + # Convert to ONNX using TorchScript-based ONNX Exporter. + # TODO: Use TorchDynamo-based ONNX Exporter once + # `torch.onnx.dynamo_export()` supports Keras models. + torch.onnx.export(model, sample_inputs, filepath, verbose=verbose) + else: + raise NotImplementedError( + "`export_onnx` is only compatible with TensorFlow, JAX and " + "Torch backends." + ) + + +def _check_jax_kwargs(kwargs): + kwargs = kwargs.copy() + if "is_static" not in kwargs: + kwargs["is_static"] = True + if "jax2tf_kwargs" not in kwargs: + # TODO: These options will be deprecated in JAX. We need to + # find another way to export ONNX. + kwargs["jax2tf_kwargs"] = { + "enable_xla": False, + "native_serialization": False, + } + if kwargs["is_static"] is not True: + raise ValueError( + "`is_static` must be `True` in `kwargs` when using the jax " + "backend." + ) + if kwargs["jax2tf_kwargs"]["enable_xla"] is not False: + raise ValueError( + "`enable_xla` must be `False` in `kwargs['jax2tf_kwargs']` " + "when using the jax backend." + ) + if kwargs["jax2tf_kwargs"]["native_serialization"] is not False: + raise ValueError( + "`native_serialization` must be `False` in " + "`kwargs['jax2tf_kwargs']` when using the jax backend." + ) + return kwargs + + +def saved_model_to_onnx(saved_model_dir, filepath, name): + from keras.src.utils.module_utils import tf2onnx + + # Convert to ONNX using `tf2onnx` library. 
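+    # Two-step conversion: the SavedModel exported above is reloaded as a
+    # frozen GraphDef via `tf2onnx.tf_loader`, then lowered to ONNX on CPU.
+    # Note that `tf2onnx.convert._convert_common` appears to be an internal
+    # tf2onnx helper rather than public API, so it may change between
+    # tf2onnx releases.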
+ (graph_def, inputs, outputs, initialized_tables, tensors_to_rename) = ( + tf2onnx.tf_loader.from_saved_model( + saved_model_dir, + None, + None, + return_initialized_tables=True, + return_tensors_to_rename=True, + ) + ) + + with tf.device("/cpu:0"): + _ = tf2onnx.convert._convert_common( + graph_def, + name=name, + target=[], + custom_op_handlers={}, + extra_opset=[], + input_names=inputs, + output_names=outputs, + tensors_to_rename=tensors_to_rename, + initialized_tables=initialized_tables, + output_path=filepath, + ) diff --git a/keras/src/export/onnx_test.py b/keras/src/export/onnx_test.py new file mode 100644 index 000000000000..2df09e3730fa --- /dev/null +++ b/keras/src/export/onnx_test.py @@ -0,0 +1,216 @@ +"""Tests for ONNX exporting utilities.""" + +import os + +import numpy as np +import onnxruntime +import pytest +from absl.testing import parameterized + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import ops +from keras.src import testing +from keras.src import tree +from keras.src.export import onnx +from keras.src.saving import saving_lib +from keras.src.testing.test_utils import named_product + + +class CustomModel(models.Model): + def __init__(self, layer_list): + super().__init__() + self.layer_list = layer_list + + def call(self, input): + output = input + for layer in self.layer_list: + output = layer(output) + return output + + +def get_model(type="sequential", input_shape=(10,), layer_list=None): + layer_list = layer_list or [ + layers.Dense(10, activation="relu"), + layers.BatchNormalization(), + layers.Dense(1, activation="sigmoid"), + ] + if type == "sequential": + return models.Sequential(layer_list) + elif type == "functional": + input = output = tree.map_shape_structure(layers.Input, input_shape) + for layer in layer_list: + output = layer(output) + return models.Model(inputs=input, outputs=output) + elif type == "subclass": + return CustomModel(layer_list) + + +@pytest.mark.skipif( + backend.backend() not in ("tensorflow", "jax", "torch"), + reason=( + "`export_onnx` only currently supports the tensorflow, jax and torch " + "backends." 
+ ), +) +@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI") +class ExportONNXTest(testing.TestCase): + @parameterized.named_parameters( + named_product(model_type=["sequential", "functional", "subclass"]) + ) + def test_standard_model_export(self, model_type): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model(model_type) + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_output = model(ref_input) + + onnx.export_onnx(model, temp_filepath) + ort_session = onnxruntime.InferenceSession(temp_filepath) + ort_inputs = { + k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input]) + } + self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0]) + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + ort_inputs = { + k.name: v + for k, v in zip( + ort_session.get_inputs(), + [np.concatenate([ref_input, ref_input], axis=0)], + ) + } + ort_session.run(None, ort_inputs) + + @parameterized.named_parameters( + named_product(struct_type=["tuple", "array", "dict"]) + ) + def test_model_with_input_structure(self, struct_type): + if backend.backend() == "torch" and struct_type == "dict": + self.skipTest("The torch backend doesn't support the dict model.") + + class TupleModel(models.Model): + def call(self, inputs): + x, y = inputs + return ops.add(x, y) + + class ArrayModel(models.Model): + def call(self, inputs): + x = inputs[0] + y = inputs[1] + return ops.add(x, y) + + class DictModel(models.Model): + def call(self, inputs): + x = inputs["x"] + y = inputs["y"] + return ops.add(x, y) + + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") + if struct_type == "tuple": + model = TupleModel() + ref_input = (ref_input, ref_input * 2) + elif struct_type == "array": + model = ArrayModel() + ref_input = [ref_input, ref_input * 2] + elif struct_type == "dict": + model = DictModel() + ref_input = {"x": ref_input, "y": ref_input * 2} + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input)) + + onnx.export_onnx(model, temp_filepath) + ort_session = onnxruntime.InferenceSession(temp_filepath) + if isinstance(ref_input, dict): + ort_inputs = { + k.name: v + for k, v in zip(ort_session.get_inputs(), ref_input.values()) + } + else: + ort_inputs = { + k.name: v for k, v in zip(ort_session.get_inputs(), ref_input) + } + self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0]) + + # Test with keras.saving_lib + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.keras" + ) + saving_lib.save_model(model, temp_filepath) + revived_model = saving_lib.load_model( + temp_filepath, + { + "TupleModel": TupleModel, + "ArrayModel": ArrayModel, + "DictModel": DictModel, + }, + ) + self.assertAllClose(ref_output, revived_model(ref_input)) + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model2") + onnx.export_onnx(revived_model, temp_filepath) + + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + bigger_ref_input = tree.map_structure( + lambda x: np.concatenate([x, x], axis=0), ref_input + ) + if isinstance(bigger_ref_input, dict): + bigger_ort_inputs = { + k.name: v + for k, v 
in zip( + ort_session.get_inputs(), bigger_ref_input.values() + ) + } + else: + bigger_ort_inputs = { + k.name: v + for k, v in zip(ort_session.get_inputs(), bigger_ref_input) + } + ort_session.run(None, bigger_ort_inputs) + + def test_model_with_multiple_inputs(self): + class TwoInputsModel(models.Model): + def call(self, x, y): + return x + y + + def build(self, y_shape, x_shape): + self.built = True + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = TwoInputsModel() + batch_size = 3 if backend.backend() != "torch" else 1 + ref_input_x = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_input_y = np.random.normal(size=(batch_size, 10)).astype("float32") + ref_output = model(ref_input_x, ref_input_y) + + onnx.export_onnx(model, temp_filepath) + ort_session = onnxruntime.InferenceSession(temp_filepath) + ort_inputs = { + k.name: v + for k, v in zip( + ort_session.get_inputs(), [ref_input_x, ref_input_y] + ) + } + self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0]) + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + ort_inputs = { + k.name: v + for k, v in zip( + ort_session.get_inputs(), + [ + np.concatenate([ref_input_x, ref_input_x], axis=0), + np.concatenate([ref_input_y, ref_input_y], axis=0), + ], + ) + } + ort_session.run(None, ort_inputs) diff --git a/keras/src/export/export_lib.py b/keras/src/export/saved_model.py similarity index 75% rename from keras/src/export/export_lib.py rename to keras/src/export/saved_model.py index a58e60c1bd3d..bc194dc67426 100644 --- a/keras/src/export/export_lib.py +++ b/keras/src/export/saved_model.py @@ -1,11 +1,11 @@ -"""Library for exporting inference-only Keras models/layers.""" +"""Library for exporting SavedModel for Keras models/layers.""" from keras.src import backend from keras.src import layers from keras.src import tree from keras.src.api_export import keras_export -from keras.src.models import Functional -from keras.src.models import Sequential +from keras.src.export.export_utils import get_input_signature +from keras.src.export.export_utils import make_tf_tensor_spec from keras.src.utils import io_utils from keras.src.utils.module_utils import tensorflow as tf @@ -326,7 +326,9 @@ def serving_fn(x): self._endpoint_names.append(name) return decorated_fn - input_signature = tree.map_structure(_make_tensor_spec, input_signature) + input_signature = tree.map_structure( + make_tf_tensor_spec, input_signature + ) decorated_fn = BackendExportArchive.add_endpoint( self, name, fn, input_signature, **kwargs ) @@ -383,7 +385,9 @@ def track_and_add_endpoint(self, name, resource, input_signature, **kwargs): f"the jax backend. Current backend: {backend.backend()}" ) - input_signature = tree.map_structure(_make_tensor_spec, input_signature) + input_signature = tree.map_structure( + make_tf_tensor_spec, input_signature + ) if not hasattr(BackendExportArchive, "track_and_add_endpoint"): # Default behavior. @@ -616,24 +620,7 @@ def export_saved_model( """ export_archive = ExportArchive() if input_signature is None: - if not model.built: - raise ValueError( - "The layer provided has not yet been built. " - "It must be built before export." 
- ) - if isinstance(model, (Functional, Sequential)): - input_signature = tree.map_structure( - _make_tensor_spec, model.inputs - ) - if isinstance(input_signature, list) and len(input_signature) > 1: - input_signature = [input_signature] - else: - input_signature = _get_input_signature(model) - if not input_signature or not model._called: - raise ValueError( - "The model provided has never called. " - "It must be called at least once before export." - ) + input_signature = get_input_signature(model) export_archive.track_and_add_endpoint( "serve", model, input_signature, **kwargs @@ -641,195 +628,6 @@ def export_saved_model( export_archive.write_out(filepath, verbose=verbose) -def _get_input_signature(model): - shapes_dict = getattr(model, "_build_shapes_dict", None) - if not shapes_dict: - return None - - def make_tensor_spec(structure): - # We need to turn wrapper structures like TrackingDict or _DictWrapper - # into plain Python structures because they don't work with jax2tf/JAX. - if isinstance(structure, dict): - return {k: make_tensor_spec(v) for k, v in structure.items()} - elif isinstance(structure, tuple): - if all(isinstance(d, (int, type(None))) for d in structure): - return tf.TensorSpec( - shape=(None,) + structure[1:], dtype=model.input_dtype - ) - return tuple(make_tensor_spec(v) for v in structure) - elif isinstance(structure, list): - if all(isinstance(d, (int, type(None))) for d in structure): - return tf.TensorSpec( - shape=[None] + structure[1:], dtype=model.input_dtype - ) - return [make_tensor_spec(v) for v in structure] - else: - raise ValueError( - f"Unsupported type {type(structure)} for {structure}" - ) - - return [make_tensor_spec(value) for value in shapes_dict.values()] - - -@keras_export("keras.layers.TFSMLayer") -class TFSMLayer(layers.Layer): - """Reload a Keras model/layer that was saved via SavedModel / ExportArchive. - - Arguments: - filepath: `str` or `pathlib.Path` object. The path to the SavedModel. - call_endpoint: Name of the endpoint to use as the `call()` method - of the reloaded layer. If the SavedModel was created - via `model.export()`, - then the default endpoint name is `'serve'`. In other cases - it may be named `'serving_default'`. - - Example: - - ```python - model.export("path/to/artifact") - reloaded_layer = TFSMLayer("path/to/artifact") - outputs = reloaded_layer(inputs) - ``` - - The reloaded object can be used like a regular Keras layer, and supports - training/fine-tuning of its trainable weights. Note that the reloaded - object retains none of the internal structure or custom methods of the - original object -- it's a brand new layer created around the saved - function. - - **Limitations:** - - * Only call endpoints with a single `inputs` tensor argument - (which may optionally be a dict/tuple/list of tensors) are supported. - For endpoints with multiple separate input tensor arguments, consider - subclassing `TFSMLayer` and implementing a `call()` method with a - custom signature. - * If you need training-time behavior to differ from inference-time behavior - (i.e. if you need the reloaded object to support a `training=True` argument - in `__call__()`), make sure that the training-time call function is - saved as a standalone endpoint in the artifact, and provide its name - to the `TFSMLayer` via the `call_training_endpoint` argument. 
- """ - - def __init__( - self, - filepath, - call_endpoint="serve", - call_training_endpoint=None, - trainable=True, - name=None, - dtype=None, - ): - if backend.backend() != "tensorflow": - raise NotImplementedError( - "The TFSMLayer is only currently supported with the " - "TensorFlow backend." - ) - - # Initialize an empty layer, then add_weight() etc. as needed. - super().__init__(trainable=trainable, name=name, dtype=dtype) - - self._reloaded_obj = tf.saved_model.load(filepath) - - self.filepath = filepath - self.call_endpoint = call_endpoint - self.call_training_endpoint = call_training_endpoint - - # Resolve the call function. - if hasattr(self._reloaded_obj, call_endpoint): - # Case 1: it's set as an attribute. - self.call_endpoint_fn = getattr(self._reloaded_obj, call_endpoint) - elif call_endpoint in self._reloaded_obj.signatures: - # Case 2: it's listed in the `signatures` field. - self.call_endpoint_fn = self._reloaded_obj.signatures[call_endpoint] - else: - raise ValueError( - f"The endpoint '{call_endpoint}' " - "is neither an attribute of the reloaded SavedModel, " - "nor an entry in the `signatures` field of " - "the reloaded SavedModel. Select another endpoint via " - "the `call_endpoint` argument. Available endpoints for " - "this SavedModel: " - f"{list(self._reloaded_obj.signatures.keys())}" - ) - - # Resolving the training function. - if call_training_endpoint: - if hasattr(self._reloaded_obj, call_training_endpoint): - self.call_training_endpoint_fn = getattr( - self._reloaded_obj, call_training_endpoint - ) - elif call_training_endpoint in self._reloaded_obj.signatures: - self.call_training_endpoint_fn = self._reloaded_obj.signatures[ - call_training_endpoint - ] - else: - raise ValueError( - f"The endpoint '{call_training_endpoint}' " - "is neither an attribute of the reloaded SavedModel, " - "nor an entry in the `signatures` field of " - "the reloaded SavedModel. Available endpoints for " - "this SavedModel: " - f"{list(self._reloaded_obj.signatures.keys())}" - ) - - # Add trainable and non-trainable weights from the call_endpoint_fn. - all_fns = [self.call_endpoint_fn] - if call_training_endpoint: - all_fns.append(self.call_training_endpoint_fn) - tvs, ntvs = _list_variables_used_by_fns(all_fns) - for v in tvs: - self._add_existing_weight(v) - for v in ntvs: - self._add_existing_weight(v) - self.built = True - - def _add_existing_weight(self, weight): - """Tracks an existing weight.""" - self._track_variable(weight) - - def call(self, inputs, training=False, **kwargs): - if training: - if self.call_training_endpoint: - return self.call_training_endpoint_fn(inputs, **kwargs) - return self.call_endpoint_fn(inputs, **kwargs) - - def get_config(self): - base_config = super().get_config() - config = { - # Note: this is not intended to be portable. - "filepath": self.filepath, - "call_endpoint": self.call_endpoint, - "call_training_endpoint": self.call_training_endpoint, - } - return {**base_config, **config} - - -def _make_tensor_spec(x): - if isinstance(x, layers.InputSpec): - if x.shape is None or x.dtype is None: - raise ValueError( - "The `shape` and `dtype` must be provided. 
" f"Received: x={x}" - ) - tensor_spec = tf.TensorSpec(x.shape, dtype=x.dtype, name=x.name) - elif isinstance(x, tf.TensorSpec): - tensor_spec = x - elif isinstance(x, backend.KerasTensor): - shape = (None,) + backend.standardize_shape(x.shape)[1:] - tensor_spec = tf.TensorSpec(shape, dtype=x.dtype, name=x.name) - elif backend.is_tensor(x): - shape = (None,) + backend.standardize_shape(x.shape)[1:] - dtype = backend.standardize_dtype(x.dtype) - tensor_spec = tf.TensorSpec(shape, dtype=dtype, name=None) - else: - raise TypeError( - f"Unsupported x={x} of the type ({type(x)}). Supported types are: " - "`keras.InputSpec`, `tf.TensorSpec`, `keras.KerasTensor` and " - "backend tensor." - ) - return tensor_spec - - def _print_signature(fn, name, verbose=True): concrete_fn = fn._list_all_concrete_functions()[0] pprinted_signature = concrete_fn.pretty_printed_signature(verbose=verbose) diff --git a/keras/src/export/export_lib_test.py b/keras/src/export/saved_model_test.py similarity index 83% rename from keras/src/export/export_lib_test.py rename to keras/src/export/saved_model_test.py index 9ee2d6fc5125..c5ad6c58690c 100644 --- a/keras/src/export/export_lib_test.py +++ b/keras/src/export/saved_model_test.py @@ -1,4 +1,4 @@ -"""Tests for inference-only model/layer exporting utilities.""" +"""Tests for SavedModel exporting utilities.""" import os @@ -14,8 +14,7 @@ from keras.src import random from keras.src import testing from keras.src import tree -from keras.src import utils -from keras.src.export import export_lib +from keras.src.export import saved_model from keras.src.saving import saving_lib from keras.src.testing.test_utils import named_product @@ -71,7 +70,7 @@ def test_standard_model_export(self, model_type): ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") ref_output = model(ref_input) - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose(ref_output, revived_model.serve(ref_input)) # Test with a different batch size @@ -106,7 +105,7 @@ def call(self, inputs): ref_input = tf.random.normal((3, 10)) ref_output = model(ref_input) - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertEqual(ref_output.shape, revived_model.serve(ref_input).shape) # Test with a different batch size @@ -142,7 +141,7 @@ def call(self, inputs): model = get_model(model_type, layer_list=[StateLayer()]) model(tf.random.normal((3, 10))) - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) # The non-trainable counter is expected to increment @@ -164,7 +163,7 @@ def test_model_with_tf_data_layer(self, model_type): ref_input = np.random.normal(size=(batch_size, 10)).astype("float32") ref_output = model(ref_input) - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose(ref_output, revived_model.serve(ref_input)) # Test with a different batch size @@ -209,7 +208,7 @@ def call(self, inputs): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input)) - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) 
revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose(ref_output, revived_model.serve(ref_input)) @@ -227,7 +226,7 @@ def call(self, inputs): }, ) self.assertAllClose(ref_output, revived_model(ref_input)) - export_lib.export_saved_model(revived_model, self.get_temp_dir()) + saved_model.export_saved_model(revived_model, self.get_temp_dir()) # Test with a different batch size if backend.backend() == "torch": @@ -253,7 +252,7 @@ def build(self, y_shape, x_shape): ref_input_y = np.random.normal(size=(batch_size, 10)).astype("float32") ref_output = model(ref_input_x, ref_input_y) - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose( ref_output, revived_model.serve(ref_input_x, ref_input_y) @@ -290,7 +289,7 @@ def test_input_signature(self, model_type, input_signature): input_signature = (ref_input,) else: input_signature = (input_signature,) - export_lib.export_saved_model( + saved_model.export_saved_model( model, temp_filepath, input_signature=input_signature ) revived_model = tf.saved_model.load(temp_filepath) @@ -303,7 +302,7 @@ def test_input_signature_error(self): model = get_model("functional") with self.assertRaisesRegex(TypeError, "Unsupported x="): input_signature = (123,) - export_lib.export_saved_model( + saved_model.export_saved_model( model, temp_filepath, input_signature=input_signature ) @@ -327,7 +326,7 @@ def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs): ref_input = ops.random.uniform((3, 10)) ref_output = model(ref_input) - export_lib.export_saved_model( + saved_model.export_saved_model( model, temp_filepath, is_static=is_static, @@ -362,13 +361,13 @@ def test_low_level_model_export(self, model_type): ref_output = model(ref_input) # Test variable tracking - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) self.assertLen(export_archive.variables, 8) self.assertLen(export_archive.trainable_variables, 6) self.assertLen(export_archive.non_trainable_variables, 2) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -388,7 +387,7 @@ def test_low_level_model_export_with_alias(self): ref_input = tf.random.normal((3, 10)) ref_output = model(ref_input) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) fn = export_archive.add_endpoint( "call", @@ -429,7 +428,7 @@ def call(self, inputs): ref_input = [tf.random.normal((3, 8)), tf.random.normal((3, 6))] ref_output = model(ref_input) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -460,7 +459,7 @@ def test_low_level_model_export_with_jax2tf_kwargs(self): ref_input = tf.random.normal((3, 10)) ref_output = model(ref_input) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -509,7 +508,7 @@ def call(self, inputs): # This will fail because the polymorphic_shapes that is # automatically generated will not account for the fact that # dynamic dimensions 1 and 2 must have the same value. 
- export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -519,7 +518,7 @@ def call(self, inputs): ) export_archive.write_out(temp_filepath) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -543,7 +542,7 @@ def test_endpoint_registration_tf_function(self): ref_output = model(ref_input) # Test variable tracking - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) self.assertLen(export_archive.variables, 8) self.assertLen(export_archive.trainable_variables, 6) @@ -608,7 +607,7 @@ def infer_fn(x): # Export with TF inference function as endpoint temp_filepath = os.path.join(self.get_temp_dir(), "my_model") - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint("serve", infer_fn) export_archive.write_out(temp_filepath) @@ -683,7 +682,7 @@ def infer_fn(x): # Export with TF inference function as endpoint temp_filepath = os.path.join(self.get_temp_dir(), "my_model") - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint("serve", infer_fn) export_archive.write_out(temp_filepath) @@ -707,7 +706,7 @@ def test_layer_export(self): ref_input = tf.random.normal((3, 10)) ref_output = layer(ref_input) # Build layer (important) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(layer) export_archive.add_endpoint( "call", @@ -729,7 +728,7 @@ def test_multi_input_output_functional_model(self): ref_inputs = [tf.random.normal((3, 2)), tf.random.normal((3, 2))] ref_outputs = model(ref_inputs) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "serve", @@ -759,7 +758,7 @@ def test_multi_input_output_functional_model(self): } ref_outputs = model(ref_inputs) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "serve", @@ -799,7 +798,7 @@ def test_multi_input_output_functional_model(self): # ref_input = tf.convert_to_tensor(["one two three four"]) # ref_output = model(ref_input) - # export_lib.export_saved_model(model, temp_filepath) + # saved_model.export_saved_model(model, temp_filepath) # revived_model = tf.saved_model.load(temp_filepath) # self.assertAllClose(ref_output, revived_model.serve(ref_input)) @@ -812,7 +811,7 @@ def test_track_multiple_layers(self): ref_input_2 = tf.random.normal((3, 5)) ref_output_2 = layer_2(ref_input_2) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.add_endpoint( "call_1", layer_1.call, @@ -835,7 +834,7 @@ def test_non_standard_layer_signature(self): x1 = tf.random.normal((3, 2, 2)) x2 = tf.random.normal((3, 2, 2)) ref_output = layer(x1, x2) # Build layer (important) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(layer) export_archive.add_endpoint( "call", @@ -856,7 +855,7 @@ def test_non_standard_layer_signature_with_kwargs(self): x1 = tf.random.normal((3, 2, 2)) x2 = tf.random.normal((3, 2, 2)) ref_output = layer(x1, x2) # Build layer 
(important) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(layer) export_archive.add_endpoint( "call", @@ -886,7 +885,7 @@ def test_variable_collection(self): ) # Test variable tracking - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -908,13 +907,13 @@ def test_export_saved_model_errors(self): # Model has not been built model = models.Sequential([layers.Dense(2)]) with self.assertRaisesRegex(ValueError, "It must be built"): - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) # Subclassed model has not been called model = get_model("subclass") model.build((2, 10)) with self.assertRaisesRegex(ValueError, "It must be called"): - export_lib.export_saved_model(model, temp_filepath) + saved_model.export_saved_model(model, temp_filepath) def test_export_archive_errors(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") @@ -922,7 +921,7 @@ def test_export_archive_errors(self): model(tf.random.normal((2, 3))) # Endpoint name reuse - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) export_archive.add_endpoint( "call", @@ -939,18 +938,18 @@ def test_export_archive_errors(self): ) # Write out with no endpoints - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) with self.assertRaisesRegex(ValueError, "No endpoints have been set"): export_archive.write_out(temp_filepath) # Invalid object type with self.assertRaisesRegex(ValueError, "Invalid resource type"): - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track("model") # Set endpoint with no input signature - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) with self.assertRaisesRegex( ValueError, "you must provide an `input_signature`" @@ -958,14 +957,14 @@ def test_export_archive_errors(self): export_archive.add_endpoint("call", model.__call__) # Set endpoint that has never been called - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) @tf.function() def my_endpoint(x): return model(x) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.track(model) with self.assertRaisesRegex( ValueError, "you must either provide a function" @@ -978,7 +977,7 @@ def test_export_no_assets(self): # Case where there are legitimately no assets. 
model = models.Sequential([layers.Flatten()]) model(tf.random.normal((2, 3))) - export_archive = export_lib.ExportArchive() + export_archive = saved_model.ExportArchive() export_archive.add_endpoint( "call", model.__call__, @@ -1000,133 +999,3 @@ def test_model_export_method(self, model_type): self.assertAllClose(ref_output, revived_model.serve(ref_input)) # Test with a different batch size revived_model.serve(tf.random.normal((6, 10))) - - -@pytest.mark.skipif( - backend.backend() != "tensorflow", - reason="TFSM Layer reloading is only for the TF backend.", -) -class TestTFSMLayer(testing.TestCase): - def test_reloading_export_archive(self): - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - model = get_model() - ref_input = tf.random.normal((3, 10)) - ref_output = model(ref_input) - - export_lib.export_saved_model(model, temp_filepath) - reloaded_layer = export_lib.TFSMLayer(temp_filepath) - self.assertAllClose(reloaded_layer(ref_input), ref_output, atol=1e-7) - self.assertLen(reloaded_layer.weights, len(model.weights)) - self.assertLen( - reloaded_layer.trainable_weights, len(model.trainable_weights) - ) - self.assertLen( - reloaded_layer.non_trainable_weights, - len(model.non_trainable_weights), - ) - - # TODO(nkovela): Expand test coverage/debug fine-tuning and - # non-trainable use cases here. - - def test_reloading_default_saved_model(self): - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - model = get_model() - ref_input = tf.random.normal((3, 10)) - ref_output = model(ref_input) - - tf.saved_model.save(model, temp_filepath) - reloaded_layer = export_lib.TFSMLayer( - temp_filepath, call_endpoint="serving_default" - ) - # The output is a dict, due to the nature of SavedModel saving. - new_output = reloaded_layer(ref_input) - self.assertAllClose( - new_output[list(new_output.keys())[0]], - ref_output, - atol=1e-7, - ) - self.assertLen(reloaded_layer.weights, len(model.weights)) - self.assertLen( - reloaded_layer.trainable_weights, len(model.trainable_weights) - ) - self.assertLen( - reloaded_layer.non_trainable_weights, - len(model.non_trainable_weights), - ) - - def test_call_training(self): - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - utils.set_random_seed(1337) - model = models.Sequential( - [ - layers.Input((10,)), - layers.Dense(10), - layers.Dropout(0.99999), - ] - ) - export_archive = export_lib.ExportArchive() - export_archive.track(model) - export_archive.add_endpoint( - name="call_inference", - fn=lambda x: model(x, training=False), - input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], - ) - export_archive.add_endpoint( - name="call_training", - fn=lambda x: model(x, training=True), - input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], - ) - export_archive.write_out(temp_filepath) - reloaded_layer = export_lib.TFSMLayer( - temp_filepath, - call_endpoint="call_inference", - call_training_endpoint="call_training", - ) - inference_output = reloaded_layer( - tf.random.normal((1, 10)), training=False - ) - training_output = reloaded_layer( - tf.random.normal((1, 10)), training=True - ) - self.assertAllClose(np.mean(training_output), 0.0, atol=1e-7) - self.assertNotAllClose(np.mean(inference_output), 0.0, atol=1e-7) - - def test_serialization(self): - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - model = get_model() - ref_input = tf.random.normal((3, 10)) - ref_output = model(ref_input) - - export_lib.export_saved_model(model, temp_filepath) - 
reloaded_layer = export_lib.TFSMLayer(temp_filepath) - - # Test reinstantiation from config - config = reloaded_layer.get_config() - rereloaded_layer = export_lib.TFSMLayer.from_config(config) - self.assertAllClose(rereloaded_layer(ref_input), ref_output, atol=1e-7) - - # Test whole model saving with reloaded layer inside - model = models.Sequential([reloaded_layer]) - temp_model_filepath = os.path.join(self.get_temp_dir(), "m.keras") - model.save(temp_model_filepath, save_format="keras_v3") - reloaded_model = saving_lib.load_model( - temp_model_filepath, - custom_objects={"TFSMLayer": export_lib.TFSMLayer}, - ) - self.assertAllClose(reloaded_model(ref_input), ref_output, atol=1e-7) - - def test_errors(self): - # Test missing call endpoint - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - model = models.Sequential([layers.Input((2,)), layers.Dense(3)]) - export_lib.export_saved_model(model, temp_filepath) - with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"): - export_lib.TFSMLayer(temp_filepath, call_endpoint="wrong") - - # Test missing call training endpoint - with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"): - export_lib.TFSMLayer( - temp_filepath, - call_endpoint="serve", - call_training_endpoint="wrong", - ) diff --git a/keras/src/export/tfsm_layer.py b/keras/src/export/tfsm_layer.py new file mode 100644 index 000000000000..61859bf0fc22 --- /dev/null +++ b/keras/src/export/tfsm_layer.py @@ -0,0 +1,139 @@ +from keras.src import backend +from keras.src import layers +from keras.src.api_export import keras_export +from keras.src.export.saved_model import _list_variables_used_by_fns +from keras.src.utils.module_utils import tensorflow as tf + + +@keras_export("keras.layers.TFSMLayer") +class TFSMLayer(layers.Layer): + """Reload a Keras model/layer that was saved via SavedModel / ExportArchive. + + Arguments: + filepath: `str` or `pathlib.Path` object. The path to the SavedModel. + call_endpoint: Name of the endpoint to use as the `call()` method + of the reloaded layer. If the SavedModel was created + via `model.export()`, + then the default endpoint name is `'serve'`. In other cases + it may be named `'serving_default'`. + + Example: + + ```python + model.export("path/to/artifact") + reloaded_layer = TFSMLayer("path/to/artifact") + outputs = reloaded_layer(inputs) + ``` + + The reloaded object can be used like a regular Keras layer, and supports + training/fine-tuning of its trainable weights. Note that the reloaded + object retains none of the internal structure or custom methods of the + original object -- it's a brand new layer created around the saved + function. + + **Limitations:** + + * Only call endpoints with a single `inputs` tensor argument + (which may optionally be a dict/tuple/list of tensors) are supported. + For endpoints with multiple separate input tensor arguments, consider + subclassing `TFSMLayer` and implementing a `call()` method with a + custom signature. + * If you need training-time behavior to differ from inference-time behavior + (i.e. if you need the reloaded object to support a `training=True` argument + in `__call__()`), make sure that the training-time call function is + saved as a standalone endpoint in the artifact, and provide its name + to the `TFSMLayer` via the `call_training_endpoint` argument. 
+ """ + + def __init__( + self, + filepath, + call_endpoint="serve", + call_training_endpoint=None, + trainable=True, + name=None, + dtype=None, + ): + if backend.backend() != "tensorflow": + raise NotImplementedError( + "The TFSMLayer is only currently supported with the " + "TensorFlow backend." + ) + + # Initialize an empty layer, then add_weight() etc. as needed. + super().__init__(trainable=trainable, name=name, dtype=dtype) + + self._reloaded_obj = tf.saved_model.load(filepath) + + self.filepath = filepath + self.call_endpoint = call_endpoint + self.call_training_endpoint = call_training_endpoint + + # Resolve the call function. + if hasattr(self._reloaded_obj, call_endpoint): + # Case 1: it's set as an attribute. + self.call_endpoint_fn = getattr(self._reloaded_obj, call_endpoint) + elif call_endpoint in self._reloaded_obj.signatures: + # Case 2: it's listed in the `signatures` field. + self.call_endpoint_fn = self._reloaded_obj.signatures[call_endpoint] + else: + raise ValueError( + f"The endpoint '{call_endpoint}' " + "is neither an attribute of the reloaded SavedModel, " + "nor an entry in the `signatures` field of " + "the reloaded SavedModel. Select another endpoint via " + "the `call_endpoint` argument. Available endpoints for " + "this SavedModel: " + f"{list(self._reloaded_obj.signatures.keys())}" + ) + + # Resolving the training function. + if call_training_endpoint: + if hasattr(self._reloaded_obj, call_training_endpoint): + self.call_training_endpoint_fn = getattr( + self._reloaded_obj, call_training_endpoint + ) + elif call_training_endpoint in self._reloaded_obj.signatures: + self.call_training_endpoint_fn = self._reloaded_obj.signatures[ + call_training_endpoint + ] + else: + raise ValueError( + f"The endpoint '{call_training_endpoint}' " + "is neither an attribute of the reloaded SavedModel, " + "nor an entry in the `signatures` field of " + "the reloaded SavedModel. Available endpoints for " + "this SavedModel: " + f"{list(self._reloaded_obj.signatures.keys())}" + ) + + # Add trainable and non-trainable weights from the call_endpoint_fn. + all_fns = [self.call_endpoint_fn] + if call_training_endpoint: + all_fns.append(self.call_training_endpoint_fn) + tvs, ntvs = _list_variables_used_by_fns(all_fns) + for v in tvs: + self._add_existing_weight(v) + for v in ntvs: + self._add_existing_weight(v) + self.built = True + + def _add_existing_weight(self, weight): + """Tracks an existing weight.""" + self._track_variable(weight) + + def call(self, inputs, training=False, **kwargs): + if training: + if self.call_training_endpoint: + return self.call_training_endpoint_fn(inputs, **kwargs) + return self.call_endpoint_fn(inputs, **kwargs) + + def get_config(self): + base_config = super().get_config() + config = { + # Note: this is not intended to be portable. 
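+            # Reloading from this config therefore requires the SavedModel
+            # to still exist at the same `filepath` on the loading machine.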
+ "filepath": self.filepath, + "call_endpoint": self.call_endpoint, + "call_training_endpoint": self.call_training_endpoint, + } + return {**base_config, **config} diff --git a/keras/src/export/tfsm_layer_test.py b/keras/src/export/tfsm_layer_test.py new file mode 100644 index 000000000000..13d49141d6f1 --- /dev/null +++ b/keras/src/export/tfsm_layer_test.py @@ -0,0 +1,145 @@ +import os + +import numpy as np +import pytest +import tensorflow as tf + +from keras.src import backend +from keras.src import layers +from keras.src import models +from keras.src import testing +from keras.src import utils +from keras.src.export import saved_model +from keras.src.export import tfsm_layer +from keras.src.export.saved_model_test import get_model +from keras.src.saving import saving_lib + + +@pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="TFSM Layer reloading is only for the TF backend.", +) +class TestTFSMLayer(testing.TestCase): + def test_reloading_export_archive(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + saved_model.export_saved_model(model, temp_filepath) + reloaded_layer = tfsm_layer.TFSMLayer(temp_filepath) + self.assertAllClose(reloaded_layer(ref_input), ref_output, atol=1e-7) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) + + # TODO(nkovela): Expand test coverage/debug fine-tuning and + # non-trainable use cases here. + + def test_reloading_default_saved_model(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + tf.saved_model.save(model, temp_filepath) + reloaded_layer = tfsm_layer.TFSMLayer( + temp_filepath, call_endpoint="serving_default" + ) + # The output is a dict, due to the nature of SavedModel saving. 
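+        # A `serving_default` signature returns a dict keyed by output name,
+        # so compare against its single value rather than the dict itself.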
+ new_output = reloaded_layer(ref_input) + self.assertAllClose( + new_output[list(new_output.keys())[0]], + ref_output, + atol=1e-7, + ) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) + + def test_call_training(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + utils.set_random_seed(1337) + model = models.Sequential( + [ + layers.Input((10,)), + layers.Dense(10), + layers.Dropout(0.99999), + ] + ) + export_archive = saved_model.ExportArchive() + export_archive.track(model) + export_archive.add_endpoint( + name="call_inference", + fn=lambda x: model(x, training=False), + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + ) + export_archive.add_endpoint( + name="call_training", + fn=lambda x: model(x, training=True), + input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)], + ) + export_archive.write_out(temp_filepath) + reloaded_layer = tfsm_layer.TFSMLayer( + temp_filepath, + call_endpoint="call_inference", + call_training_endpoint="call_training", + ) + inference_output = reloaded_layer( + tf.random.normal((1, 10)), training=False + ) + training_output = reloaded_layer( + tf.random.normal((1, 10)), training=True + ) + self.assertAllClose(np.mean(training_output), 0.0, atol=1e-7) + self.assertNotAllClose(np.mean(inference_output), 0.0, atol=1e-7) + + def test_serialization(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model() + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + saved_model.export_saved_model(model, temp_filepath) + reloaded_layer = tfsm_layer.TFSMLayer(temp_filepath) + + # Test reinstantiation from config + config = reloaded_layer.get_config() + rereloaded_layer = tfsm_layer.TFSMLayer.from_config(config) + self.assertAllClose(rereloaded_layer(ref_input), ref_output, atol=1e-7) + + # Test whole model saving with reloaded layer inside + model = models.Sequential([reloaded_layer]) + temp_model_filepath = os.path.join(self.get_temp_dir(), "m.keras") + model.save(temp_model_filepath, save_format="keras_v3") + reloaded_model = saving_lib.load_model( + temp_model_filepath, + custom_objects={"TFSMLayer": tfsm_layer.TFSMLayer}, + ) + self.assertAllClose(reloaded_model(ref_input), ref_output, atol=1e-7) + + def test_errors(self): + # Test missing call endpoint + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = models.Sequential([layers.Input((2,)), layers.Dense(3)]) + saved_model.export_saved_model(model, temp_filepath) + with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"): + tfsm_layer.TFSMLayer(temp_filepath, call_endpoint="wrong") + + # Test missing call training endpoint + with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"): + tfsm_layer.TFSMLayer( + temp_filepath, + call_endpoint="serve", + call_training_endpoint="wrong", + ) diff --git a/keras/src/layers/core/dense_test.py b/keras/src/layers/core/dense_test.py index 2c2faac218a1..b54c91c9e193 100644 --- a/keras/src/layers/core/dense_test.py +++ b/keras/src/layers/core/dense_test.py @@ -6,6 +6,7 @@ from keras.src import backend from keras.src import constraints +from keras.src import export from keras.src import layers from keras.src import models from keras.src import ops @@ -14,7 +15,6 @@ from keras.src import saving from keras.src import testing from 
keras.src.backend.common import keras_tensor -from keras.src.export import export_lib class DenseTest(testing.TestCase): @@ -566,7 +566,7 @@ def test_quantize_int8_when_lora_enabled(self): ref_input = tf.random.normal((2, 8)) ref_output = model(ref_input) model.export(temp_filepath, format="tf_saved_model") - reloaded_layer = export_lib.TFSMLayer(temp_filepath) + reloaded_layer = export.TFSMLayer(temp_filepath) self.assertAllClose( reloaded_layer(ref_input), ref_output, atol=1e-7 ) @@ -738,7 +738,7 @@ def test_quantize_float8_fitting(self): ref_input = tf.random.normal((2, 8)) ref_output = model(ref_input) model.export(temp_filepath, format="tf_saved_model") - reloaded_layer = export_lib.TFSMLayer(temp_filepath) + reloaded_layer = export.TFSMLayer(temp_filepath) self.assertAllClose(reloaded_layer(ref_input), ref_output) self.assertLen(reloaded_layer.weights, len(model.weights)) self.assertLen( diff --git a/keras/src/layers/core/einsum_dense_test.py b/keras/src/layers/core/einsum_dense_test.py index 796cb37fd767..3fcecef0310c 100644 --- a/keras/src/layers/core/einsum_dense_test.py +++ b/keras/src/layers/core/einsum_dense_test.py @@ -6,6 +6,7 @@ from keras.src import backend from keras.src import constraints +from keras.src import export from keras.src import layers from keras.src import models from keras.src import ops @@ -13,7 +14,6 @@ from keras.src import random from keras.src import saving from keras.src import testing -from keras.src.export import export_lib class EinsumDenseTest(testing.TestCase): @@ -699,7 +699,7 @@ def test_quantize_int8_when_lora_enabled( ref_input = tf.random.normal(input_shape) ref_output = model(ref_input) model.export(temp_filepath, format="tf_saved_model") - reloaded_layer = export_lib.TFSMLayer(temp_filepath) + reloaded_layer = export.TFSMLayer(temp_filepath) self.assertAllClose( reloaded_layer(ref_input), ref_output, atol=1e-7 ) @@ -878,7 +878,7 @@ def test_quantize_float8_fitting(self): ref_input = tf.random.normal((2, 3)) ref_output = model(ref_input) model.export(temp_filepath, format="tf_saved_model") - reloaded_layer = export_lib.TFSMLayer(temp_filepath) + reloaded_layer = export.TFSMLayer(temp_filepath) self.assertAllClose(reloaded_layer(ref_input), ref_output) self.assertLen(reloaded_layer.weights, len(model.weights)) self.assertLen( diff --git a/keras/src/layers/core/embedding_test.py b/keras/src/layers/core/embedding_test.py index ac4b6d6c8c74..784216c4cc80 100644 --- a/keras/src/layers/core/embedding_test.py +++ b/keras/src/layers/core/embedding_test.py @@ -6,11 +6,11 @@ from keras.src import backend from keras.src import constraints +from keras.src import export from keras.src import layers from keras.src import models from keras.src import ops from keras.src import saving -from keras.src.export import export_lib from keras.src.testing import test_case @@ -439,7 +439,7 @@ def test_quantize_when_lora_enabled(self): ref_input = tf.random.normal((32, 3)) ref_output = model(ref_input) model.export(temp_filepath, format="tf_saved_model") - reloaded_layer = export_lib.TFSMLayer(temp_filepath) + reloaded_layer = export.TFSMLayer(temp_filepath) self.assertAllClose( reloaded_layer(ref_input), ref_output, atol=1e-7 ) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 1de2ba0f2350..8e36bb20456b 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1472,9 +1472,19 @@ def _check_super_called(self): def _assert_input_compatibility(self, arg_0): if self.input_spec: - input_spec.assert_input_compatibility( - 
                self.input_spec, arg_0, layer_name=self.name
-            )
+            try:
+                input_spec.assert_input_compatibility(
+                    self.input_spec, arg_0, layer_name=self.name
+                )
+            except SystemError:
+                if backend.backend() == "torch":
+                    # TODO: The torch backend failed the ONNX CI with the error:
+                    # SystemError: returned a result with an exception set
+                    # As a workaround, we are skipping this for now.
+                    pass
+                else:
+                    raise
 
     def _get_call_context(self):
         """Returns currently active `CallContext`."""
diff --git a/keras/src/models/model.py b/keras/src/models/model.py
index 832e0b35b369..46f103076544 100644
--- a/keras/src/models/model.py
+++ b/keras/src/models/model.py
@@ -470,15 +470,12 @@ def export(
     ):
         """Export the model as an artifact for inference.
 
-        **Note:** This feature is currently supported only with TensorFlow and
-        JAX backends.
-
-        **Note:** Currently, only `format="tf_saved_model"` is supported.
-
         Args:
             filepath: `str` or `pathlib.Path` object. The path to save the
                 artifact.
-            format: `str`. The export format. Supported value:
-                `"tf_saved_model"`. Defaults to `"tf_saved_model"`.
+            format: `str`. The export format. Supported values:
+                `"tf_saved_model"` and `"onnx"`. Defaults to
+                `"tf_saved_model"`.
             verbose: `bool`. Whether to print a message during export. Defaults
                 to `True`.
             input_signature: Optional. Specifies the shape and dtype of the
@@ -487,7 +484,7 @@ def export(
                 not provided, it will be automatically computed. Defaults to
                 `None`.
             **kwargs: Additional keyword arguments:
-                - Specific to the JAX backend:
+                - Specific to the JAX backend and `format="tf_saved_model"`:
                     - `is_static`: Optional `bool`. Indicates whether `fn` is
                         static. Set to `False` if `fn` involves state updates
                         (e.g., RNG seeds and counters).
@@ -498,7 +495,12 @@ def export(
                         If `native_serialization` and `polymorphic_shapes` are
                         not provided, they will be automatically computed.
 
-        Example:
+        **Note:** This feature is currently supported only with the TensorFlow,
+        JAX, and Torch backends.
+
+        Examples:
+
+        Here's how to export a TensorFlow SavedModel for inference.
 
         ```python
         # Export the model as a TensorFlow SavedModel artifact
         model.export("path/to/location", format="tf_saved_model")
 
         # Load the artifact in a different process/environment
         reloaded_artifact = tf.saved_model.load("path/to/location")
         predictions = reloaded_artifact.serve(input_data)
         ```
+
+        Here's how to export an ONNX artifact for inference.
+
+        ```python
+        # Export the model as an ONNX artifact
+        model.export("path/to/location", format="onnx")
+
+        # Load the artifact in a different process/environment
+        ort_session = onnxruntime.InferenceSession("path/to/location")
+        ort_inputs = {
+            k.name: v for k, v in zip(ort_session.get_inputs(), input_data)
+        }
+        predictions = ort_session.run(None, ort_inputs)
+        ```
         """
-        from keras.src.export import export_lib
+        from keras.src.export import export_onnx
+        from keras.src.export import export_saved_model
 
-        available_formats = ("tf_saved_model",)
+        available_formats = ("tf_saved_model", "onnx")
        if format not in available_formats:
            raise ValueError(
                f"Unrecognized format={format}. 
Supported formats are: " @@ -519,7 +536,15 @@ def export( ) if format == "tf_saved_model": - export_lib.export_saved_model( + export_saved_model( + self, + filepath, + verbose, + input_signature=input_signature, + **kwargs, + ) + elif format == "onnx": + export_onnx( self, filepath, verbose, diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 212fbad58871..eb83cad42356 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -1219,6 +1219,10 @@ def test_functional_deeply_nested_outputs_struct_losses(self): ) self.assertListEqual(hist_keys, ref_keys) + @parameterized.named_parameters( + ("tf_saved_model", "tf_saved_model"), + ("onnx", "onnx"), + ) @pytest.mark.skipif( backend.backend() not in ("tensorflow", "jax", "torch"), reason=( @@ -1229,29 +1233,60 @@ def test_functional_deeply_nested_outputs_struct_losses(self): @pytest.mark.skipif( testing.jax_uses_gpu(), reason="Leads to core dumps on CI" ) - @pytest.mark.skipif( - testing.torch_uses_gpu(), reason="Leads to core dumps on CI" - ) - def test_export(self): - import tensorflow as tf + def test_export(self, export_format): + if export_format == "tf_saved_model" and testing.torch_uses_gpu(): + self.skipTest("Leads to core dumps on CI") temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") model = _get_model() - x1 = np.random.rand(1, 3) - x2 = np.random.rand(1, 3) + x1 = np.random.rand(1, 3).astype("float32") + x2 = np.random.rand(1, 3).astype("float32") ref_output = model([x1, x2]) - model.export(temp_filepath) - revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(ref_output, revived_model.serve([x1, x2])) + model.export(temp_filepath, format=export_format) - # Test with a different batch size - if backend.backend() == "torch": - # TODO: Dynamic shape is not supported yet in the torch backend - return - revived_model.serve( - [np.concatenate([x1, x1], axis=0), np.concatenate([x2, x2], axis=0)] - ) + if export_format == "tf_saved_model": + import tensorflow as tf + + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.serve([x1, x2])) + + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + revived_model.serve( + [ + np.concatenate([x1, x1], axis=0), + np.concatenate([x2, x2], axis=0), + ] + ) + elif export_format == "onnx": + import onnxruntime + + ort_session = onnxruntime.InferenceSession(temp_filepath) + ort_inputs = { + k.name: v for k, v in zip(ort_session.get_inputs(), [x1, x2]) + } + self.assertAllClose( + ref_output, ort_session.run(None, ort_inputs)[0] + ) + + # Test with a different batch size + if backend.backend() == "torch": + # TODO: Dynamic shape is not supported yet in the torch backend + return + ort_inputs = { + k.name: v + for k, v in zip( + ort_session.get_inputs(), + [ + np.concatenate([x1, x1], axis=0), + np.concatenate([x2, x2], axis=0), + ], + ) + } + ort_session.run(None, ort_inputs) def test_export_error(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") diff --git a/keras/src/utils/module_utils.py b/keras/src/utils/module_utils.py index 1a1dbac619bf..190bc8dc72fe 100644 --- a/keras/src/utils/module_utils.py +++ b/keras/src/utils/module_utils.py @@ -58,3 +58,4 @@ def __repr__(self): ) optree = LazyModule("optree") dmtree = LazyModule("tree") +tf2onnx = LazyModule("tf2onnx") diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 
a368d191f3de..7b1d2166f638 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -1,5 +1,6 @@ # Tensorflow cpu-only version (needed for testing). tensorflow-cpu~=2.18.0 +tf2onnx # Torch cpu-only version (needed for testing). --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt index 25ce69eeb7d0..fed601f658f2 100644 --- a/requirements-tensorflow-cuda.txt +++ b/requirements-tensorflow-cuda.txt @@ -1,5 +1,6 @@ # Tensorflow with cuda support. tensorflow[and-cuda]~=2.18.0 +tf2onnx # Torch cpu-only version (needed for testing). --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt index 14abde44bdb5..d165faa16280 100644 --- a/requirements-torch-cuda.txt +++ b/requirements-torch-cuda.txt @@ -1,5 +1,6 @@ # Tensorflow cpu-only version (needed for testing). tensorflow-cpu~=2.18.0 +tf2onnx # Torch with cuda support. --extra-index-url https://download.pytorch.org/whl/cu121 diff --git a/requirements.txt b/requirements.txt index c1baf145d002..0973be4969aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ tensorflow-cpu~=2.18.0;sys_platform != 'darwin' tensorflow~=2.18.0;sys_platform == 'darwin' tf_keras +tf2onnx # Torch. # TODO: Pin to < 2.3.0 (GitHub issue #19602) From e1e35be04cf11b29a464e1332302b10a56dc3389 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 2 Jan 2025 11:40:43 +0800 Subject: [PATCH 2/4] Add onnxruntime dep --- requirements-common.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements-common.txt b/requirements-common.txt index 2d1ec92d9118..51c682f9ef41 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -20,3 +20,5 @@ packaging # for tree_test.py dm_tree coverage!=7.6.5 # 7.6.5 breaks CI +# for onnx_test.py +onnxruntime From 3c16cd4ccb0295fd94975bd75a809fc4842287ef Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 2 Jan 2025 12:40:01 +0800 Subject: [PATCH 3/4] Update numpy dep --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index 51c682f9ef41..4d47532deec7 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -1,7 +1,7 @@ namex>=0.0.8 ruff pytest -numpy +numpy<2.0.0 # TODO: Remove the restriction when tf2onnx supports numpy>2.0.0 scipy scikit-learn pandas From 5a94db210ad08bac5e3f27dbe751456f4d0014db Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Fri, 3 Jan 2025 23:29:06 +0800 Subject: [PATCH 4/4] Resolve comments --- keras/src/export/onnx.py | 2 +- keras/src/export/tfsm_layer_test.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/keras/src/export/onnx.py b/keras/src/export/onnx.py index acca68bdcc32..0a66192de2dd 100644 --- a/keras/src/export/onnx.py +++ b/keras/src/export/onnx.py @@ -35,7 +35,7 @@ def export_onnx(model, filepath, verbose=True, input_signature=None, **kwargs): **Note:** The dtype policy must be "float32" for the model. You can further optimize the ONNX artifact using the ONNX toolkit. Learn more here: - https://onnxruntime.ai/docs/performance/. + [https://onnxruntime.ai/docs/performance/](https://onnxruntime.ai/docs/performance/). **Note:** The dynamic shape feature is not yet supported with Torch backend. 
As a result, you must fully define the shapes of the inputs using diff --git a/keras/src/export/tfsm_layer_test.py b/keras/src/export/tfsm_layer_test.py index 13d49141d6f1..31cb1673cf10 100644 --- a/keras/src/export/tfsm_layer_test.py +++ b/keras/src/export/tfsm_layer_test.py @@ -38,9 +38,6 @@ def test_reloading_export_archive(self): len(model.non_trainable_weights), ) - # TODO(nkovela): Expand test coverage/debug fine-tuning and - # non-trainable use cases here. - def test_reloading_default_saved_model(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") model = get_model()
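For reference, here is a minimal end-to-end sketch of the workflow this patch series enables. It is not part of the patch itself; the model definition, file name, and shapes are illustrative, and it assumes `onnxruntime` is installed (per the `requirements-common.txt` change above) and that a TensorFlow, JAX, or Torch backend is active.

```python
import numpy as np

import keras
from keras import layers

# Build a small model and export it through the new ONNX path
# (same pattern as the tests added in this patch).
model = keras.Sequential([layers.Input((10,)), layers.Dense(1)])
model.export("model.onnx", format="onnx")

# Run inference on the artifact, e.g. in a separate process/environment.
import onnxruntime

session = onnxruntime.InferenceSession("model.onnx")
sample = np.random.rand(2, 10).astype("float32")
onnx_inputs = {
    inp.name: x for inp, x in zip(session.get_inputs(), [sample])
}
predictions = session.run(None, onnx_inputs)[0]
print(predictions.shape)
```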