Skip to content

Commit aa270a2

Browse files
authored
Hacky fix for dictionary output with tf 2.14 (#933)
1 parent ff60e34 commit aa270a2

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

keras_core/ops/function.py

+12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from keras_core.api_export import keras_core_export
66
from keras_core.backend import KerasTensor
7+
from keras_core.backend.config import backend
78
from keras_core.ops.operation import Operation
89
from keras_core.utils.nest import pack_sequence_as
910

@@ -46,10 +47,21 @@ class Function(Operation):
4647
def __init__(self, inputs, outputs, name=None):
4748
super().__init__(name=name)
4849

50+
if backend() == "tensorflow":
51+
# Temporary work around for
52+
# https://github.com/keras-team/keras-core/issues/931
53+
# This stop tensorflow from wrapping tf.function output in a
54+
# _DictWrapper object.
55+
_self_setattr_tracking = getattr(
56+
self, "_self_setattr_tracking", True
57+
)
58+
self._self_setattr_tracking = False
4959
self._inputs_struct = tree.map_structure(lambda x: x, inputs)
5060
self._outputs_struct = tree.map_structure(lambda x: x, outputs)
5161
self._inputs = tree.flatten(inputs)
5262
self._outputs = tree.flatten(outputs)
63+
if backend() == "tensorflow":
64+
self._self_setattr_tracking = _self_setattr_tracking
5365

5466
(nodes, nodes_by_depth, operations, operations_by_depth) = map_graph(
5567
self._inputs, self._outputs

0 commit comments

Comments
 (0)