|
4 | 4 |
|
5 | 5 | from keras_core.api_export import keras_core_export
|
6 | 6 | from keras_core.backend import KerasTensor
|
| 7 | +from keras_core.backend.config import backend |
7 | 8 | from keras_core.ops.operation import Operation
|
8 | 9 | from keras_core.utils.nest import pack_sequence_as
|
9 | 10 |
|
@@ -46,10 +47,21 @@ class Function(Operation):
|
46 | 47 | def __init__(self, inputs, outputs, name=None):
|
47 | 48 | super().__init__(name=name)
|
48 | 49 |
|
| 50 | + if backend() == "tensorflow": |
| 51 | + # Temporary work around for |
| 52 | + # https://github.com/keras-team/keras-core/issues/931 |
| 53 | + # This stop tensorflow from wrapping tf.function output in a |
| 54 | + # _DictWrapper object. |
| 55 | + _self_setattr_tracking = getattr( |
| 56 | + self, "_self_setattr_tracking", True |
| 57 | + ) |
| 58 | + self._self_setattr_tracking = False |
49 | 59 | self._inputs_struct = tree.map_structure(lambda x: x, inputs)
|
50 | 60 | self._outputs_struct = tree.map_structure(lambda x: x, outputs)
|
51 | 61 | self._inputs = tree.flatten(inputs)
|
52 | 62 | self._outputs = tree.flatten(outputs)
|
| 63 | + if backend() == "tensorflow": |
| 64 | + self._self_setattr_tracking = _self_setattr_tracking |
53 | 65 |
|
54 | 66 | (nodes, nodes_by_depth, operations, operations_by_depth) = map_graph(
|
55 | 67 | self._inputs, self._outputs
|
|
0 commit comments