
Refactor keras/src/export/export_lib and add export_onnx #20710

Merged 4 commits on Jan 4, 2025
1 change: 0 additions & 1 deletion .kokoro/github/ubuntu/gpu/build.sh
@@ -72,7 +72,6 @@ then
 # Raise error if GPU is not detected.
 python3 -c 'import torch;assert torch.cuda.is_available()'

-# TODO: keras/src/export/export_lib_test.py update LD_LIBRARY_PATH
 pytest keras --ignore keras/src/applications \
     --cov=keras \
     --cov-config=pyproject.toml
2 changes: 1 addition & 1 deletion keras/api/_tf_keras/keras/export/__init__.py
@@ -4,4 +4,4 @@
 since your modifications would be overwritten.
 """

-from keras.src.export.export_lib import ExportArchive
+from keras.src.export.saved_model import ExportArchive
2 changes: 1 addition & 1 deletion keras/api/_tf_keras/keras/layers/__init__.py
@@ -4,7 +4,7 @@
 since your modifications would be overwritten.
 """

-from keras.src.export.export_lib import TFSMLayer
+from keras.src.export.tfsm_layer import TFSMLayer
 from keras.src.layers import deserialize
 from keras.src.layers import serialize
 from keras.src.layers.activations.activation import Activation
2 changes: 1 addition & 1 deletion keras/api/export/__init__.py
@@ -4,4 +4,4 @@
 since your modifications would be overwritten.
 """

-from keras.src.export.export_lib import ExportArchive
+from keras.src.export.saved_model import ExportArchive
2 changes: 1 addition & 1 deletion keras/api/layers/__init__.py
@@ -4,7 +4,7 @@
 since your modifications would be overwritten.
 """

-from keras.src.export.export_lib import TFSMLayer
+from keras.src.export.tfsm_layer import TFSMLayer
 from keras.src.layers import deserialize
 from keras.src.layers import serialize
 from keras.src.layers.activations.activation import Activation
24 changes: 5 additions & 19 deletions keras/src/backend/torch/export.py
@@ -3,9 +3,8 @@

 import torch

-from keras.src import backend
-from keras.src import ops
 from keras.src import tree
+from keras.src.export.export_utils import convert_spec_to_tensor
 from keras.src.utils.module_utils import tensorflow as tf
 from keras.src.utils.module_utils import torch_xla

@@ -36,23 +35,10 @@ def track_and_add_endpoint(self, name, resource, input_signature, **kwargs):
                 f"Received: resource={resource} (of type {type(resource)})"
             )

-        def _check_input_signature(input_spec):
-            for s in tree.flatten(input_spec.shape):
-                if s is None:
-                    raise ValueError(
-                        "The shape in the `input_spec` must be fully "
-                        f"specified. Received: input_spec={input_spec}"
-                    )
-
-        def _to_torch_tensor(x, replace_none_number=1):
-            shape = backend.standardize_shape(x.shape)
-            shape = tuple(
-                s if s is not None else replace_none_number for s in shape
-            )
-            return ops.ones(shape, x.dtype)
-
-        tree.map_structure(_check_input_signature, input_signature)
-        sample_inputs = tree.map_structure(_to_torch_tensor, input_signature)
+        sample_inputs = tree.map_structure(
+            lambda x: convert_spec_to_tensor(x, replace_none_number=1),
+            input_signature,
+        )
         sample_inputs = tuple(sample_inputs)

         # Ref: torch_xla.tf_saved_model_integration
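For context, here is a minimal sketch (not part of the diff) of what the new `convert_spec_to_tensor` helper does with a dynamic batch dimension; the shape below is a hypothetical example:

```python
# Sketch: the helper replaces None dims (e.g. the batch size) with
# `replace_none_number` and returns a ones tensor of that shape,
# mirroring the removed `_to_torch_tensor` closure.
from keras.src import layers
from keras.src.export.export_utils import convert_spec_to_tensor

spec = layers.InputSpec(shape=(None, 28, 28), dtype="float32")
sample = convert_spec_to_tensor(spec, replace_none_number=1)
# sample is a backend tensor of shape (1, 28, 28).
```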
5 changes: 4 additions & 1 deletion keras/src/export/__init__.py
@@ -1 +1,4 @@
-from keras.src.export.export_lib import ExportArchive
+from keras.src.export.onnx import export_onnx
+from keras.src.export.saved_model import ExportArchive
+from keras.src.export.saved_model import export_saved_model
+from keras.src.export.tfsm_layer import TFSMLayer
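As a quick sanity check of the submodule's resulting public surface (illustrative only):

```python
# All four names should now import from keras.src.export.
from keras.src.export import ExportArchive
from keras.src.export import TFSMLayer
from keras.src.export import export_onnx
from keras.src.export import export_saved_model
```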
105 changes: 105 additions & 0 deletions keras/src/export/export_utils.py
@@ -0,0 +1,105 @@
from keras.src import backend
from keras.src import layers
from keras.src import models
from keras.src import ops
from keras.src import tree
from keras.src.utils.module_utils import tensorflow as tf


def get_input_signature(model):
    if not isinstance(model, models.Model):
        raise TypeError(
            "The model must be a `keras.Model`. "
            f"Received: model={model} of type {type(model)}"
        )
    if not model.built:
        raise ValueError(
            "The model provided has not yet been built. It must be built "
            "before export."
        )
    if isinstance(model, (models.Functional, models.Sequential)):
        input_signature = tree.map_structure(make_input_spec, model.inputs)
        if isinstance(input_signature, list) and len(input_signature) > 1:
            input_signature = [input_signature]
    else:
        input_signature = _infer_input_signature_from_model(model)
        if not input_signature or not model._called:
            raise ValueError(
                "The model provided has never been called. "
                "It must be called at least once before export."
            )
    return input_signature


def _infer_input_signature_from_model(model):
    shapes_dict = getattr(model, "_build_shapes_dict", None)
    if not shapes_dict:
        return None

    def _make_input_spec(structure):
        # We need to turn wrapper structures like TrackingDict or _DictWrapper
        # into plain Python structures because they don't work with jax2tf/JAX.
        if isinstance(structure, dict):
            return {k: _make_input_spec(v) for k, v in structure.items()}
        elif isinstance(structure, tuple):
            if all(isinstance(d, (int, type(None))) for d in structure):
                return layers.InputSpec(
                    shape=(None,) + structure[1:], dtype=model.input_dtype
                )
            return tuple(_make_input_spec(v) for v in structure)
        elif isinstance(structure, list):
            if all(isinstance(d, (int, type(None))) for d in structure):
                return layers.InputSpec(
                    shape=[None] + structure[1:], dtype=model.input_dtype
                )
            return [_make_input_spec(v) for v in structure]
        else:
            raise ValueError(
                f"Unsupported type {type(structure)} for {structure}"
            )

    return [_make_input_spec(value) for value in shapes_dict.values()]


def make_input_spec(x):
    if isinstance(x, layers.InputSpec):
        if x.shape is None or x.dtype is None:
            raise ValueError(
                "The `shape` and `dtype` must be provided. "
                f"Received: x={x}"
            )
        input_spec = x
    elif isinstance(x, backend.KerasTensor):
        shape = (None,) + backend.standardize_shape(x.shape)[1:]
        dtype = backend.standardize_dtype(x.dtype)
        input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=x.name)
    elif backend.is_tensor(x):
        shape = (None,) + backend.standardize_shape(x.shape)[1:]
        dtype = backend.standardize_dtype(x.dtype)
        input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=None)
    else:
        raise TypeError(
            f"Unsupported x={x} of type {type(x)}. Supported types are: "
            "`keras.InputSpec`, `keras.KerasTensor` and backend tensor."
        )
    return input_spec


def make_tf_tensor_spec(x):
    if isinstance(x, tf.TensorSpec):
        tensor_spec = x
    else:
        input_spec = make_input_spec(x)
        tensor_spec = tf.TensorSpec(
            input_spec.shape, dtype=input_spec.dtype, name=input_spec.name
        )
    return tensor_spec


def convert_spec_to_tensor(spec, replace_none_number=None):
    shape = backend.standardize_shape(spec.shape)
    if replace_none_number is not None:
        replace_none_number = int(replace_none_number)
        shape = tuple(
            s if s is not None else replace_none_number for s in shape
        )
    return ops.ones(shape, spec.dtype)
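Taken together, these helpers normalize whatever the caller passes (an `InputSpec`, a `KerasTensor`, or a backend tensor) into a uniform signature. A minimal usage sketch follows; the model and shapes are hypothetical, and TensorFlow must be installed for the `tf.TensorSpec` step:

```python
# Sketch: derive a signature from a built model, then convert it to
# tf.TensorSpec objects (e.g. for tf.function or jax2tf wrapping).
import numpy as np
import keras
from keras.src.export.export_utils import get_input_signature
from keras.src.export.export_utils import make_tf_tensor_spec

model = keras.Sequential([keras.layers.Dense(4)])
model(np.ones((2, 8), dtype="float32"))  # build by calling once

signature = get_input_signature(model)  # structure of InputSpec
tf_specs = keras.tree.map_structure(make_tf_tensor_spec, signature)
```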
162 changes: 162 additions & 0 deletions keras/src/export/onnx.py
@@ -0,0 +1,162 @@
import pathlib
import tempfile

from keras.src import backend
from keras.src import tree
from keras.src.export.export_utils import convert_spec_to_tensor
from keras.src.export.export_utils import get_input_signature
from keras.src.export.saved_model import export_saved_model
from keras.src.utils.module_utils import tensorflow as tf


def export_onnx(model, filepath, verbose=True, input_signature=None, **kwargs):
    """Export the model as an ONNX artifact for inference.

    This method lets you export a model to a lightweight ONNX artifact
    that contains the model's forward pass only (its `call()` method)
    and can be served via e.g. ONNX Runtime.

    The original code of the model (including any custom layers you may
    have used) is *no longer* necessary to reload the artifact -- it is
    entirely standalone.

    Args:
        filepath: `str` or `pathlib.Path` object. The path to save the
            artifact.
        verbose: `bool`. Whether to print a message during export. Defaults
            to `True`.
        input_signature: Optional. Specifies the shape and dtype of the
            model inputs. Can be a structure of `keras.InputSpec`,
            `tf.TensorSpec`, `backend.KerasTensor`, or backend tensor. If
            not provided, it will be automatically computed. Defaults to
            `None`.
        **kwargs: Additional keyword arguments.

    **Note:** This feature is currently supported only with the TensorFlow,
    JAX and Torch backends.

    **Note:** The dtype policy must be "float32" for the model. You can
    further optimize the ONNX artifact using the ONNX toolkit. Learn more
    here: [https://onnxruntime.ai/docs/performance/](https://onnxruntime.ai/docs/performance/).

    **Note:** The dynamic shape feature is not yet supported with the Torch
    backend. As a result, you must fully define the shapes of the inputs
    using `input_signature`. If `input_signature` is not provided, all
    instances of `None` (such as the batch size) will be replaced with `1`.

    Example:

    ```python
    # Export the model as an ONNX artifact
    model.export("path/to/location", format="onnx")

    # Load the artifact in a different process/environment
    ort_session = onnxruntime.InferenceSession("path/to/location")
    ort_inputs = {
        k.name: v for k, v in zip(ort_session.get_inputs(), input_data)
    }
    predictions = ort_session.run(None, ort_inputs)
    ```
    """
    if input_signature is None:
        input_signature = get_input_signature(model)
        if not input_signature or not model._called:
            raise ValueError(
                "The model provided has never been called. "
                "It must be called at least once before export."
            )

    if backend.backend() in ("tensorflow", "jax"):
        working_dir = pathlib.Path(filepath).parent
        with tempfile.TemporaryDirectory(dir=working_dir) as temp_dir:
            if backend.backend() == "jax":
                kwargs = _check_jax_kwargs(kwargs)
            export_saved_model(
                model,
                temp_dir,
                verbose,
                input_signature,
                **kwargs,
            )
            saved_model_to_onnx(temp_dir, filepath, model.name)

    elif backend.backend() == "torch":
        import torch

        sample_inputs = tree.map_structure(
            lambda x: convert_spec_to_tensor(x, replace_none_number=1),
            input_signature,
        )
        sample_inputs = tuple(sample_inputs)
        # TODO: Make dict model exportable.
        if any(isinstance(x, dict) for x in sample_inputs):
            raise ValueError(
                "Currently, `export_onnx` in the torch backend doesn't "
                "support dictionaries as inputs."
            )

        # Convert to ONNX using TorchScript-based ONNX Exporter.
        # TODO: Use TorchDynamo-based ONNX Exporter once
        # `torch.onnx.dynamo_export()` supports Keras models.
        torch.onnx.export(model, sample_inputs, filepath, verbose=verbose)
    else:
        raise NotImplementedError(
            "`export_onnx` is only compatible with TensorFlow, JAX and "
            "Torch backends."
        )


def _check_jax_kwargs(kwargs):
    kwargs = kwargs.copy()
    if "is_static" not in kwargs:
        kwargs["is_static"] = True
    if "jax2tf_kwargs" not in kwargs:
        # TODO: These options will be deprecated in JAX. We need to
        # find another way to export ONNX.
        kwargs["jax2tf_kwargs"] = {
            "enable_xla": False,
            "native_serialization": False,
        }
    if kwargs["is_static"] is not True:
        raise ValueError(
            "`is_static` must be `True` in `kwargs` when using the jax "
            "backend."
        )
    if kwargs["jax2tf_kwargs"]["enable_xla"] is not False:
        raise ValueError(
            "`enable_xla` must be `False` in `kwargs['jax2tf_kwargs']` "
            "when using the jax backend."
        )
    if kwargs["jax2tf_kwargs"]["native_serialization"] is not False:
        raise ValueError(
            "`native_serialization` must be `False` in "
            "`kwargs['jax2tf_kwargs']` when using the jax backend."
        )
    return kwargs


def saved_model_to_onnx(saved_model_dir, filepath, name):
    from keras.src.utils.module_utils import tf2onnx

    # Convert to ONNX using the `tf2onnx` library.
    (graph_def, inputs, outputs, initialized_tables, tensors_to_rename) = (
        tf2onnx.tf_loader.from_saved_model(
            saved_model_dir,
            None,
            None,
            return_initialized_tables=True,
            return_tensors_to_rename=True,
        )
    )

    with tf.device("/cpu:0"):
        _ = tf2onnx.convert._convert_common(
            graph_def,
            name=name,
            target=[],
            custom_op_handlers={},
            extra_opset=[],
            input_names=inputs,
            output_names=outputs,
            tensors_to_rename=tensors_to_rename,
            initialized_tables=initialized_tables,
            output_path=filepath,
        )
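Finally, a hedged end-to-end sketch of the new entry point; the model, path, and data below are placeholders, and both `tf2onnx` (for the TF/JAX backends) and `onnxruntime` need to be installed:

```python
# Sketch: export a toy model to ONNX and run it with ONNX Runtime.
import numpy as np
import keras
import onnxruntime
from keras.src.export.onnx import export_onnx

model = keras.Sequential([keras.layers.Dense(2)])
model(np.ones((1, 4), dtype="float32"))  # must be called once before export

# Also reachable as model.export("model.onnx", format="onnx").
export_onnx(model, "model.onnx")

sess = onnxruntime.InferenceSession("model.onnx")
inputs = {sess.get_inputs()[0].name: np.ones((1, 4), dtype="float32")}
print(sess.run(None, inputs))
```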