Skip to content

Commit b491c86

Browse files
rtg0795hertschuhmattdangerw
authored
Version bump to 3.11.3 (#21607)
* Disable `torch.load` in `TorchModuleWrapper` when in safe mode. (#21575) Raise an exception and explain the user about the risks. * Propagate `safe_mode` flag to legacy h5 loading code. (#21602) Also: - made various error messages related to `safe_mode` more consistent - removed no-op renaming code in legacy saving - uncommented unit tests in `serialization_lib_test.py` * Fix GRU with return_state=True on tf backend with cuda (#21603) * Version bump to 3.11.3 --------- Co-authored-by: hertschuh <[email protected]> Co-authored-by: Matt Watson <[email protected]>
1 parent 251ac34 commit b491c86

File tree

14 files changed

+140
-85
lines changed

14 files changed

+140
-85
lines changed

keras/src/backend/tensorflow/rnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,7 @@ def _cudnn_gru(
778778
return (
779779
last_output,
780780
outputs,
781-
state,
781+
[state],
782782
)
783783

784784

keras/src/layers/core/lambda_layer.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,15 @@ def _serialize_function_to_config(self, fn):
167167
)
168168

169169
@staticmethod
170-
def _raise_for_lambda_deserialization(arg_name, safe_mode):
170+
def _raise_for_lambda_deserialization(safe_mode):
171171
if safe_mode:
172172
raise ValueError(
173-
f"The `{arg_name}` of this `Lambda` layer is a Python lambda. "
174-
"Deserializing it is unsafe. If you trust the source of the "
175-
"config artifact, you can override this error "
176-
"by passing `safe_mode=False` "
177-
"to `from_config()`, or calling "
173+
"Requested the deserialization of a `Lambda` layer whose "
174+
"`function` is a Python lambda. This carries a potential risk "
175+
"of arbitrary code execution and thus it is disallowed by "
176+
"default. If you trust the source of the artifact, you can "
177+
"override this error by passing `safe_mode=False` to the "
178+
"loading function, or calling "
178179
"`keras.config.enable_unsafe_deserialization()."
179180
)
180181

@@ -187,7 +188,7 @@ def from_config(cls, config, custom_objects=None, safe_mode=None):
187188
and "class_name" in fn_config
188189
and fn_config["class_name"] == "__lambda__"
189190
):
190-
cls._raise_for_lambda_deserialization("function", safe_mode)
191+
cls._raise_for_lambda_deserialization(safe_mode)
191192
inner_config = fn_config["config"]
192193
fn = python_utils.func_load(
193194
inner_config["code"],
@@ -206,7 +207,7 @@ def from_config(cls, config, custom_objects=None, safe_mode=None):
206207
and "class_name" in fn_config
207208
and fn_config["class_name"] == "__lambda__"
208209
):
209-
cls._raise_for_lambda_deserialization("function", safe_mode)
210+
cls._raise_for_lambda_deserialization(safe_mode)
210211
inner_config = fn_config["config"]
211212
fn = python_utils.func_load(
212213
inner_config["code"],

keras/src/layers/rnn/gru_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,41 @@ def test_pass_initial_state(self):
205205
output,
206206
)
207207

208+
def test_pass_return_state(self):
209+
sequence = np.arange(24).reshape((2, 4, 3)).astype("float32")
210+
initial_state = np.arange(4).reshape((2, 2)).astype("float32")
211+
212+
# Test with go_backwards=False
213+
layer = layers.GRU(
214+
2,
215+
kernel_initializer=initializers.Constant(0.01),
216+
recurrent_initializer=initializers.Constant(0.02),
217+
bias_initializer=initializers.Constant(0.03),
218+
return_state=True,
219+
)
220+
output, state = layer(sequence, initial_state=initial_state)
221+
self.assertAllClose(
222+
np.array([[0.23774096, 0.33508456], [0.83659905, 1.0227708]]),
223+
output,
224+
)
225+
self.assertAllClose(output, state)
226+
227+
# Test with go_backwards=True
228+
layer = layers.GRU(
229+
2,
230+
kernel_initializer=initializers.Constant(0.01),
231+
recurrent_initializer=initializers.Constant(0.02),
232+
bias_initializer=initializers.Constant(0.03),
233+
return_state=True,
234+
go_backwards=True,
235+
)
236+
output, state = layer(sequence, initial_state=initial_state)
237+
self.assertAllClose(
238+
np.array([[0.13486053, 0.23261218], [0.78257304, 0.9691353]]),
239+
output,
240+
)
241+
self.assertAllClose(output, state)
242+
208243
def test_masking(self):
209244
sequence = np.arange(24).reshape((2, 4, 3)).astype("float32")
210245
mask = np.array([[True, True, False, True], [True, False, False, True]])

keras/src/legacy/saving/legacy_h5_format.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from keras.src.legacy.saving import saving_options
1212
from keras.src.legacy.saving import saving_utils
1313
from keras.src.saving import object_registration
14+
from keras.src.saving import serialization_lib
1415
from keras.src.utils import io_utils
1516

1617
try:
@@ -72,7 +73,9 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
7273
f.close()
7374

7475

75-
def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
76+
def load_model_from_hdf5(
77+
filepath, custom_objects=None, compile=True, safe_mode=True
78+
):
7679
"""Loads a model saved via `save_model_to_hdf5`.
7780
7881
Args:
@@ -128,7 +131,9 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
128131
model_config = model_config.decode("utf-8")
129132
model_config = json_utils.decode(model_config)
130133

131-
with saving_options.keras_option_scope(use_legacy_config=True):
134+
legacy_scope = saving_options.keras_option_scope(use_legacy_config=True)
135+
safe_mode_scope = serialization_lib.SafeModeScope(safe_mode)
136+
with legacy_scope, safe_mode_scope:
132137
model = saving_utils.model_from_config(
133138
model_config, custom_objects=custom_objects
134139
)

keras/src/legacy/saving/legacy_h5_format_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,13 @@ def test_saving_lambda(self):
158158

159159
temp_filepath = os.path.join(self.get_temp_dir(), "lambda_model.h5")
160160
legacy_h5_format.save_model_to_hdf5(model, temp_filepath)
161-
loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath)
162161

162+
with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
163+
legacy_h5_format.load_model_from_hdf5(temp_filepath)
164+
165+
loaded = legacy_h5_format.load_model_from_hdf5(
166+
temp_filepath, safe_mode=False
167+
)
163168
self.assertAllClose(mean, loaded.layers[1].arguments["mu"])
164169
self.assertAllClose(std, loaded.layers[1].arguments["std"])
165170

@@ -353,8 +358,13 @@ def test_saving_lambda(self):
353358

354359
temp_filepath = os.path.join(self.get_temp_dir(), "lambda_model.h5")
355360
tf_keras_model.save(temp_filepath)
356-
loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath)
357361

362+
with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
363+
legacy_h5_format.load_model_from_hdf5(temp_filepath)
364+
365+
loaded = legacy_h5_format.load_model_from_hdf5(
366+
temp_filepath, safe_mode=False
367+
)
358368
self.assertAllClose(mean, loaded.layers[1].arguments["mu"])
359369
self.assertAllClose(std, loaded.layers[1].arguments["std"])
360370

keras/src/legacy/saving/saving_utils.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
import threading
32

43
from absl import logging
@@ -81,10 +80,6 @@ def model_from_config(config, custom_objects=None):
8180
function_dict["config"]["closure"] = function_config[2]
8281
config["config"]["function"] = function_dict
8382

84-
# TODO(nkovela): Swap find and replace args during Keras 3.0 release
85-
# Replace keras refs with keras
86-
config = _find_replace_nested_dict(config, "keras.", "keras.")
87-
8883
return serialization.deserialize_keras_object(
8984
config,
9085
module_objects=MODULE_OBJECTS.ALL_OBJECTS,
@@ -231,13 +226,6 @@ def _deserialize_metric(metric_config):
231226
return metrics_module.deserialize(metric_config)
232227

233228

234-
def _find_replace_nested_dict(config, find, replace):
235-
dict_str = json.dumps(config)
236-
dict_str = dict_str.replace(find, replace)
237-
config = json.loads(dict_str)
238-
return config
239-
240-
241229
def _resolve_compile_arguments_compat(obj, obj_config, module):
242230
"""Resolves backwards compatibility issues with training config arguments.
243231

keras/src/legacy/saving/serialization.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import contextlib
44
import inspect
5-
import json
65
import threading
76
import weakref
87

@@ -485,12 +484,6 @@ def deserialize(config, custom_objects=None):
485484
arg_spec = inspect.getfullargspec(cls.from_config)
486485
custom_objects = custom_objects or {}
487486

488-
# TODO(nkovela): Swap find and replace args during Keras 3.0 release
489-
# Replace keras refs with keras
490-
cls_config = _find_replace_nested_dict(
491-
cls_config, "keras.", "keras."
492-
)
493-
494487
if "custom_objects" in arg_spec.args:
495488
deserialized_obj = cls.from_config(
496489
cls_config,
@@ -565,10 +558,3 @@ def validate_config(config):
565558
def is_default(method):
566559
"""Check if a method is decorated with the `default` wrapper."""
567560
return getattr(method, "_is_default", False)
568-
569-
570-
def _find_replace_nested_dict(config, find, replace):
571-
dict_str = json.dumps(config)
572-
dict_str = dict_str.replace(find, replace)
573-
config = json.loads(dict_str)
574-
return config

keras/src/saving/saving_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,10 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
194194
)
195195
if str(filepath).endswith((".h5", ".hdf5")):
196196
return legacy_h5_format.load_model_from_hdf5(
197-
filepath, custom_objects=custom_objects, compile=compile
197+
filepath,
198+
custom_objects=custom_objects,
199+
compile=compile,
200+
safe_mode=safe_mode,
198201
)
199202
elif str(filepath).endswith(".keras"):
200203
raise ValueError(

keras/src/saving/saving_lib_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ def test_safe_mode(self):
880880
]
881881
)
882882
model.save(temp_filepath)
883-
with self.assertRaisesRegex(ValueError, "Deserializing it is unsafe"):
883+
with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
884884
model = saving_lib.load_model(temp_filepath)
885885
model = saving_lib.load_model(temp_filepath, safe_mode=False)
886886

keras/src/saving/serialization_lib.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -656,12 +656,12 @@ class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
656656
if config["class_name"] == "__lambda__":
657657
if safe_mode:
658658
raise ValueError(
659-
"Requested the deserialization of a `lambda` object. "
660-
"This carries a potential risk of arbitrary code execution "
661-
"and thus it is disallowed by default. If you trust the "
662-
"source of the saved model, you can pass `safe_mode=False` to "
663-
"the loading function in order to allow `lambda` loading, "
664-
"or call `keras.config.enable_unsafe_deserialization()`."
659+
"Requested the deserialization of a Python lambda. This "
660+
"carries a potential risk of arbitrary code execution and thus "
661+
"it is disallowed by default. If you trust the source of the "
662+
"artifact, you can override this error by passing "
663+
"`safe_mode=False` to the loading function, or calling "
664+
"`keras.config.enable_unsafe_deserialization()."
665665
)
666666
return python_utils.func_load(inner_config["value"])
667667
if tf is not None and config["class_name"] == "__typespec__":

0 commit comments

Comments
 (0)