Infrastructure for saving/loading hls4ml models #1158

Merged: 13 commits, May 5, 2025
46 changes: 46 additions & 0 deletions docs/api/serialization.rst
@@ -0,0 +1,46 @@
============================
Saving/Loading hls4ml models
============================

``hls4ml`` model objects (instances of the ``ModelGraph`` class) can be saved to disk and loaded at a later stage. The saved model does not require the original Keras/PyTorch/ONNX model for loading.

To save or load a model, use the following API:

.. code-block:: python

    from hls4ml.converters import convert_from_keras_model, load_saved_model

    model = convert_from_keras_model(keras_model, ...)

    # Save a model to some path
    model.save('some/path/my_hls4ml_model.fml')

    # Load a model from a file
    loaded_model = load_saved_model('some/path/my_hls4ml_model.fml')


The saved model will have a ``.fml`` extension but is in fact a gzipped tar archive. A loaded model can be used in the same way as the original one, including modification of certain config parameters, for example the output directory or a layer's reuse factor.
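
Because the ``.fml`` file is an ordinary gzipped tar archive, its contents can be inspected with standard tooling. A minimal sketch, reusing the file name from the example above:

.. code-block:: python

    import tarfile

    # A .fml file is a gzipped tar archive, so the standard library can
    # list the files bundled inside a saved model
    with tarfile.open('some/path/my_hls4ml_model.fml', 'r:gz') as archive:
        print(archive.getnames())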

Linking with an existing project
================================

Once the project has been written to disk with ``ModelGraph.write()``, it can also be linked at a later stage. As with loading a saved model, this feature allows skipping the conversion step. Additionally, it may be used to test manual changes to the generated project.

The linking function will create a special instance of ``ModelGraph`` that only allows calls to ``compile()``, ``predict()`` and ``build()``. All other calls on the ``ModelGraph`` instance are disabled.

To link a model, use the following API:

.. code-block:: python

    from hls4ml.converters import convert_from_keras_model, link_existing_project

    model = convert_from_keras_model(keras_model, output_dir='some/path/', ...)

    # Generate the project files and write them to some path
    model.write()

    # Later on, link this path to the Python runtime
    linked_model = link_existing_project('some/path/')
    linked_model.compile()
    linked_model.predict(...)
    linked_model.build(...)
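
The restriction to the three methods can be pictured as a thin wrapper that defines only the permitted calls and rejects everything else. The sketch below is illustrative only and assumes nothing about the real ``FilesystemModelGraph`` internals:

.. code-block:: python

    class LinkedProjectSketch:
        """Hypothetical stand-in for a linked project; not the actual implementation."""

        def __init__(self, project_dir):
            self.project_dir = project_dir

        def compile(self):
            ...  # compile the already-generated HLS project into a callable library

        def predict(self, x):
            ...  # run inference through the compiled library

        def build(self, **kwargs):
            ...  # launch the backend build flow on the existing project

        def __getattr__(self, name):
            # Reached only for attributes not defined above, i.e. the rest
            # of the ModelGraph API, which is disabled for linked projects
            raise AttributeError(f"'{name}' is not available on a linked project")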
1 change: 1 addition & 0 deletions docs/index.rst
@@ -17,6 +17,7 @@
    api/concepts
    api/configuration
    api/command
+   api/serialization
 
 .. toctree::
    :hidden:
6 changes: 5 additions & 1 deletion hls4ml/backends/fpga/fpga_backend.py
@@ -143,8 +143,12 @@ def create_layer_class(self, layer_class):
             if issubclass(layer_class, cls):
                 new_attrubutes.extend(attributes)
 
+        layer_cls_fqn = layer_class.__module__ + '.' + layer_class.__qualname__
+
         return type(
-            self.name + layer_class.__name__, (layer_class,), {'_expected_attributes': new_attrubutes, '_wrapped': True}
+            self.name + layer_class.__name__,
+            (layer_class,),
+            {'_expected_attributes': new_attrubutes, '_wrapped': layer_cls_fqn},
         )
 
     def compile(self, model):
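
Storing the wrapped class's fully qualified name, rather than a bare `True`, gives the serialization machinery enough information to find the original class again when a model is loaded. A minimal sketch of how such a name could be resolved, using a hypothetical helper that does not appear in this diff:

```python
import importlib


def resolve_class(fqn: str):
    """Resolve a fully qualified class name such as the one stored in '_wrapped'."""
    module_name, _, class_name = fqn.rpartition('.')
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# For example, resolve_class('hls4ml.model.layers.Dense') returns the
# unwrapped Dense layer class (path shown for illustration).
```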
33 changes: 26 additions & 7 deletions hls4ml/backends/fpga/fpga_types.py
@@ -103,6 +103,7 @@ def __init__(self, type_map, prefix):
     def convert(self, precision_type):
         type_cls = type(precision_type)
         type_cls_name = type_cls.__name__
+        type_cls_fqn = type_cls.__module__ + '.' + type_cls.__qualname__
 
         # If the type is already converted, do nothing
         if type_cls_name.startswith(self.prefix):
@@ -111,7 +112,9 @@ def convert(self, precision_type):
         definition_cls = self.type_map.get(type_cls, None)
 
         if definition_cls is not None:
-            precision_type.__class__ = type(self.prefix + type_cls_name, (type_cls, definition_cls), {})
+            precision_type.__class__ = type(
+                self.prefix + type_cls_name, (type_cls, definition_cls), {'_wrapped': type_cls_fqn}
+            )
             return precision_type
         else:
             raise Exception(f'Cannot convert precision type to {self.prefix}: {precision_type.__class__.__name__}')
@@ -206,6 +209,7 @@ def convert(self, atype):
     def convert(self, atype):
         type_cls = type(atype)
         type_cls_name = type_cls.__name__
+        type_cls_fqn = type_cls.__module__ + '.' + type_cls.__qualname__
 
         # If the type is already converted, do nothing
         if type_cls_name.startswith('HLS'):
@@ -214,7 +218,7 @@ def convert(self, atype):
         conversion_cls = self.type_map.get(type_cls, None)
 
         if conversion_cls is not None:
-            atype.__class__ = type('HLS' + type_cls_name, (type_cls, conversion_cls), {})
+            atype.__class__ = type('HLS' + type_cls_name, (type_cls, conversion_cls), {'_wrapped': type_cls_fqn})
             atype.convert_precision(self.precision_converter)
             return atype
         else:
@@ -246,8 +250,11 @@ def convert(self, tensor_var, pragma='partition'):
 
         tensor_var.pragma = pragma
         tensor_var.type = self.type_converter.convert(tensor_var.type)
+        tensor_cls_fqn = tensor_var.__class__.__module__ + '.' + tensor_var.__class__.__qualname__
 
-        tensor_var.__class__ = type(self.prefix + 'ArrayVariable', (type(tensor_var), self.definition_cls), {})
+        tensor_var.__class__ = type(
+            self.prefix + 'ArrayVariable', (type(tensor_var), self.definition_cls), {'_wrapped': tensor_cls_fqn}
+        )
         return tensor_var


@@ -273,8 +280,11 @@ def convert(self, tensor_var, pragma='partition', struct_name=None):
         tensor_var.struct_name = str(struct_name)
         tensor_var.member_name = tensor_var.name
         tensor_var.name = tensor_var.struct_name + '.' + tensor_var.member_name
+        type_cls_fqn = tensor_var.__class__.__module__ + '.' + tensor_var.__class__.__qualname__
 
-        tensor_var.__class__ = type(self.prefix + 'StructMemberVariable', (type(tensor_var), self.definition_cls), {})
+        tensor_var.__class__ = type(
+            self.prefix + 'StructMemberVariable', (type(tensor_var), self.definition_cls), {'_wrapped': type_cls_fqn}
+        )
         return tensor_var


@@ -299,8 +309,11 @@ def convert(self, tensor_var, n_pack=1, depth=0):
         tensor_var.type = self.type_converter.convert(
             PackedType(tensor_var.type.name, tensor_var.type.precision, tensor_var.shape[-1], n_pack)
         )
+        tensor_cls_fqn = tensor_var.__class__.__module__ + '.' + tensor_var.__class__.__qualname__
 
-        tensor_var.__class__ = type(self.prefix + 'StreamVariable', (type(tensor_var), self.definition_cls), {})
+        tensor_var.__class__ = type(
+            self.prefix + 'StreamVariable', (type(tensor_var), self.definition_cls), {'_wrapped': tensor_cls_fqn}
+        )
         return tensor_var


@@ -318,8 +331,11 @@ def convert(self, tensor_var, n_pack=1, depth=0):
         tensor_var.type = self.type_converter.convert(
             PackedType(tensor_var.type.name, tensor_var.type.precision, tensor_var.input_var.shape[-1], n_pack)
         )
+        tensor_cls_fqn = tensor_var.__class__.__module__ + '.' + tensor_var.__class__.__qualname__
 
-        tensor_var.__class__ = type(self.prefix + 'StreamVariable', (type(tensor_var), self.definition_cls), {})
+        tensor_var.__class__ = type(
+            self.prefix + 'StreamVariable', (type(tensor_var), self.definition_cls), {'_wrapped': tensor_cls_fqn}
+        )
         return tensor_var


@@ -344,8 +360,11 @@ def convert(self, weight_var):
         weight_var.weight_class = weight_var.__class__.__name__
         weight_var.storage = 'register'
         weight_var.type = self.type_converter.convert(weight_var.type)
+        tensor_cls_fqn = weight_var.__class__.__module__ + '.' + weight_var.__class__.__qualname__
 
-        weight_var.__class__ = type('StaticWeightVariable', (type(weight_var), StaticWeightVariableDefinition), {})
+        weight_var.__class__ = type(
+            'StaticWeightVariable', (type(weight_var), StaticWeightVariableDefinition), {'_wrapped': tensor_cls_fqn}
+        )
         return weight_var


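All of these converters follow the same pattern: the object's class is swapped at runtime for a dynamically created subclass that mixes in a backend-specific definition class, and `_wrapped` now records the original class's fully qualified name so the swap can be undone on deserialization. A stripped-down sketch of the pattern, with placeholder class names:

```python
class Variable:
    """Stand-in for a framework-level variable class."""


class BackendDefinition:
    """Stand-in for a backend-specific definition mixin."""

    def definition_cpp(self):
        return '// backend-specific C++ definition'


var = Variable()
fqn = var.__class__.__module__ + '.' + var.__class__.__qualname__

# Swap the instance's class for a dynamic subclass; '_wrapped' remembers
# where the original class lives so serialization can recover it later.
var.__class__ = type('BackendVariable', (type(var), BackendDefinition), {'_wrapped': fqn})

print(var.definition_cpp())  # '// backend-specific C++ definition'
print(var._wrapped)          # e.g. '__main__.Variable'
```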
10 changes: 8 additions & 2 deletions hls4ml/backends/oneapi/oneapi_types.py
@@ -143,8 +143,11 @@ def convert(self, tensor_var, pragma='', depth=0, n_pack=1):
         # pipe_name and pipe_id are only used for io_stream and interface variables in io_parallel
         tensor_var.pipe_name = f'{convert_to_pascal_case(tensor_var.name)}Pipe'
         tensor_var.pipe_id = f'{convert_to_pascal_case(tensor_var.name)}PipeID'
+        tensor_cls_fqn = tensor_var.__class__.__module__ + '.' + tensor_var.__class__.__qualname__
 
-        tensor_var.__class__ = type(self.prefix + 'AggregateArrayVariable', (type(tensor_var), self.definition_cls), {})
+        tensor_var.__class__ = type(
+            self.prefix + 'AggregateArrayVariable', (type(tensor_var), self.definition_cls), {'_wrapped': tensor_cls_fqn}
+        )
         return tensor_var


@@ -255,9 +258,12 @@ def convert(self, weight_var):
         weight_var.type = self.type_converter.convert(
             PackedType(weight_var.name + '_t', weight_var.type.precision, weight_var.data_length, 1)
         )
+        weight_cls_fqn = weight_var.__class__.__module__ + '.' + weight_var.__class__.__qualname__
 
         weight_var.__class__ = type(
-            'OneAPIStaticWeightVariable', (type(weight_var), OneAPIStaticWeightVariableDefinition), {}
+            'OneAPIStaticWeightVariable',
+            (type(weight_var), OneAPIStaticWeightVariableDefinition),
+            {'_wrapped': weight_cls_fqn},
         )
         return weight_var

40 changes: 39 additions & 1 deletion hls4ml/converters/__init__.py
@@ -20,6 +20,8 @@
 from hls4ml.model import ModelGraph
 from hls4ml.utils.config import create_config
 from hls4ml.utils.dependency import requires
+from hls4ml.utils.link import FilesystemModelGraph
+from hls4ml.utils.serialization import deserialize_model
 from hls4ml.utils.symbolic_utils import LUTFunction
 
 # ----------Layer handling register----------#
@@ -464,6 +466,42 @@ def convert_from_symbolic_expression(
 
     config['HLSConfig'] = {'Model': {'Precision': precision, 'ReuseFactor': 1}}
 
-    hls_model = ModelGraph(config, layer_list)
+    hls_model = ModelGraph.from_layer_list(config, layer_list)
 
     return hls_model
+
+
+def link_existing_project(project_dir):
+    """Create a stripped-down ModelGraph from an existing project previously generated by hls4ml.
+
+    The returned ModelGraph will only allow compile(), predict() and build() functions to be invoked.
+
+    Args:
+        project_dir (str): Path to the existing HLS project previously generated with hls4ml.
+
+    Returns:
+        FilesystemModelGraph: hls4ml model.
+    """
+    return FilesystemModelGraph(project_dir)
+
+
+def load_saved_model(file_path, output_dir=None):
+    """
+    Loads an hls4ml model from a compressed file format (.fml).
+
+    See `hls4ml.utils.serialization.deserialize_model` for more details.
+
+    Args:
+        file_path (str or pathlib.Path): The path to the serialized model file (.fml).
+        output_dir (str or pathlib.Path, optional): The directory where extracted
+            testbench data files will be saved. If not specified, the files will
+            be restored to the same directory as the `.fml` file.
+
+    Returns:
+        ModelGraph: The deserialized hls4ml model.
+
+    Raises:
+        FileNotFoundError: If the specified `.fml` file does not exist.
+        OSError: If an I/O error occurs during extraction or file operations.
+    """
+    return deserialize_model(file_path, output_dir=output_dir)
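
Based on the docstring above, a typical load might look like the following (paths are placeholders):

```python
from hls4ml.converters import load_saved_model

try:
    # Extracted testbench data is restored next to the .fml file
    # unless output_dir says otherwise
    model = load_saved_model('some/path/my_hls4ml_model.fml', output_dir='some/other/dir')
except FileNotFoundError:
    print('No saved model at that path')
```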
2 changes: 1 addition & 1 deletion hls4ml/converters/keras_to_hls.py
@@ -360,5 +360,5 @@ def keras_to_hls(config):
     model_arch, reader = get_model_arch(config)
     layer_list, input_layers, output_layers, _ = parse_keras_model(model_arch, reader)
     print('Creating HLS model')
-    hls_model = ModelGraph(config, layer_list, input_layers, output_layers)
+    hls_model = ModelGraph.from_layer_list(config, layer_list, input_layers, output_layers)
     return hls_model
2 changes: 1 addition & 1 deletion hls4ml/converters/onnx_to_hls.py
@@ -292,5 +292,5 @@ def onnx_to_hls(config):
     #################
 
     print('Creating HLS model')
-    hls_model = ModelGraph(config, layer_list, input_layers, output_layers)
+    hls_model = ModelGraph.from_layer_list(config, layer_list, input_layers, output_layers)
     return hls_model
2 changes: 1 addition & 1 deletion hls4ml/converters/pytorch_to_hls.py
@@ -426,5 +426,5 @@ def parse_pytorch_model(config, verbose=True):
 def pytorch_to_hls(config):
     layer_list, input_layers, output_layers = parse_pytorch_model(config)
     print('Creating HLS model')
-    hls_model = ModelGraph(config, layer_list, inputs=input_layers, outputs=output_layers)
+    hls_model = ModelGraph.from_layer_list(config, layer_list, inputs=input_layers, outputs=output_layers)
     return hls_model
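
All three converters now go through `ModelGraph.from_layer_list` instead of calling the constructor directly. A plausible reading of this split, offered as an assumption since `ModelGraph` internals are not part of this diff, is the two-constructor pattern: the plain constructor stays lightweight enough for the deserializer to use, while converters take the full parse-and-build path:

```python
class ModelGraphSketch:
    """Illustrative only; the real ModelGraph carries far more state."""

    def __init__(self, config):
        # Lightweight constructor, also usable when restoring a
        # serialized model that already has its graph built
        self.config = config
        self.graph = {}

    @classmethod
    def from_layer_list(cls, config, layer_list, inputs=None, outputs=None):
        model = cls(config)
        for layer in layer_list:
            # Build graph nodes from the parsed layer descriptions
            model.graph[layer['name']] = layer
        return model
```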