From a1739266512f2e892a2e2fbf2b15eb0b017c4ee0 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Mon, 3 Feb 2025 14:58:39 +0100 Subject: [PATCH 1/2] [Transform] Introduce ComposedTransformation This allows to exhaustively apply a sequence of transformations until no individual transformation modifies the graph anymore. --- src/qonnx/transformation/composed.py | 66 ++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 src/qonnx/transformation/composed.py diff --git a/src/qonnx/transformation/composed.py b/src/qonnx/transformation/composed.py new file mode 100644 index 00000000..abe410d7 --- /dev/null +++ b/src/qonnx/transformation/composed.py @@ -0,0 +1,66 @@ +# Copies (deep-copies) python objects +import copy + +# QONNX wrapper of ONNX model graphs +from qonnx.core.modelwrapper import ModelWrapper + +# Base class for all QONNX graph transformations and some basic cleanup +# transformations +from qonnx.transformation.general import ( + Transformation, + GiveUniqueNodeNames, + GiveReadableTensorNames, +) +# Cleanup transformations removing identities like multiplication by one or +# addition of zero +from qonnx.transformation.remove import RemoveIdentityOps +# QONNX graph transformations for annotating the graph with datatype and shape +# information +from qonnx.transformation.infer_datatypes import InferDataTypes +from qonnx.transformation.infer_shapes import InferShapes + + +# Composes graph transformations such that each individual transformation as +# well as the whole sequence is applied exhaustively +class ComposedTransformation(Transformation): + # Initializes the transformation given a list of transformations + def __init__(self, transformations: list[Transformation]): + # Initialize the transformation base class + super().__init__() + # Register the list of transformations to be applied in apply() + self.transformations = transformations + + # Applies the transform to a whole model graph + def apply(self, model: ModelWrapper): # noqa + # Keep track of whether the graph has been modified + graph_modified = False + # Iterate all transformations to be applied + for transformation in self.transformations: + # Start each transformation on a deep copy of the model to mimic the + # behavior of ModelWrapper.transform() + model = copy.deepcopy(model) + # Exhaustively apply the transformation until it no longer modifies + # the graph + while True: + # Apply the transformation once, reporting back whether any node + # or pattern has been modified + model, _graph_modified = transformation.apply(model) + # Keep track whether the graph has been modified at least once + graph_modified = graph_modified or _graph_modified + # Break the loop if this transformation did not change anything + if not _graph_modified: + break + # Apply the cleanup transformations of the ModelWrapper + model.cleanup() + # Apply some further cleanup transformations to the model graph + # removing some clutter and keeping all names readable and ordered + # at any time + model = model.transform(RemoveIdentityOps()) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveReadableTensorNames()) + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + # Return the transformed model and indicate whether the graph actually + # has been transformed by at least one transformation so the whole + # sequence of transformations will be reapplied + return model, graph_modified From 54c6722b2ccca11861cbb25813d298094252f367 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Mon, 3 Feb 2025 15:21:43 +0100 Subject: [PATCH 2/2] Fix some linitng issues --- .isort.cfg | 4 ++++ src/qonnx/transformation/composed.py | 24 ++++++++++++++---------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/.isort.cfg b/.isort.cfg index 5d85f60a..09b99040 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -9,3 +9,7 @@ sections=FUTURE,STDLIB,TEST,THIRDPARTY,FIRSTPARTY,LOCALFOLDER default_section=THIRDPARTY multi_line_output=3 profile=black +ignore_comments=true +ignore_whitespace=true +honor_noqa=true +use_parentheses=true diff --git a/src/qonnx/transformation/composed.py b/src/qonnx/transformation/composed.py index abe410d7..a69ad677 100644 --- a/src/qonnx/transformation/composed.py +++ b/src/qonnx/transformation/composed.py @@ -4,21 +4,25 @@ # QONNX wrapper of ONNX model graphs from qonnx.core.modelwrapper import ModelWrapper -# Base class for all QONNX graph transformations and some basic cleanup -# transformations -from qonnx.transformation.general import ( - Transformation, - GiveUniqueNodeNames, - GiveReadableTensorNames, -) -# Cleanup transformations removing identities like multiplication by one or -# addition of zero -from qonnx.transformation.remove import RemoveIdentityOps # QONNX graph transformations for annotating the graph with datatype and shape # information from qonnx.transformation.infer_datatypes import InferDataTypes from qonnx.transformation.infer_shapes import InferShapes +# Cleanup transformations removing identities like multiplication by one or +# addition of zero +from qonnx.transformation.remove import RemoveIdentityOps + +# Base class for all QONNX graph transformations and some basic cleanup +# transformations +# fmt: off +from qonnx.transformation.general import ( # isort: skip + GiveReadableTensorNames, GiveUniqueNodeNames, Transformation +) + + +# fmt: on + # Composes graph transformations such that each individual transformation as # well as the whole sequence is applied exhaustively