diff --git a/export/orbax/export/export_manager.py b/export/orbax/export/export_manager.py index c0a9d12f1..f2b1873de 100644 --- a/export/orbax/export/export_manager.py +++ b/export/orbax/export/export_manager.py @@ -96,6 +96,7 @@ def save( additional signatures to export. tree_verity_options: Settings to enable model hashing and signing via + inference_converter_options: Options for the TPU Inference Converter V2. """ self._serialization_functions.save( model_path=model_path, diff --git a/export/orbax/export/export_manager_obm_test.py b/export/orbax/export/export_manager_obm_test.py index 32640a9a3..e21fb2f42 100644 --- a/export/orbax/export/export_manager_obm_test.py +++ b/export/orbax/export/export_manager_obm_test.py @@ -18,7 +18,6 @@ from absl.testing import absltest from absl.testing import parameterized - from orbax.export import constants from orbax.export import export_manager from orbax.export import export_testing_utils diff --git a/export/orbax/export/obm_export.py b/export/orbax/export/obm_export.py index 1378b8eab..ac57d9dd5 100644 --- a/export/orbax/export/obm_export.py +++ b/export/orbax/export/obm_export.py @@ -36,6 +36,7 @@ from orbax.export.modules import obm_module import tensorflow as tf + _obm_export_config = config.config diff --git a/export/orbax/export/tensorflow_export.py b/export/orbax/export/tensorflow_export.py index 62e7c7ee9..bb5cc2f4c 100644 --- a/export/orbax/export/tensorflow_export.py +++ b/export/orbax/export/tensorflow_export.py @@ -15,6 +15,7 @@ """Export class that implements the save and load abstract class defined in Export Base for use with the TensorFlow SavedModel export format.""" from collections.abc import Callable, Mapping, Sequence +import os from typing import Any from absl import logging @@ -85,9 +86,14 @@ def save( if signature_overrides: serving_signatures.update(signature_overrides) + converter_options = kwargs.get('inference_converter_options') + tf_model_path = ( + os.path.join(model_path, 'tmp') if converter_options else model_path + ) + tf.saved_model.save( self._tf_module, - model_path, + tf_model_path, serving_signatures, options=save_options, )