From ca4d25c7a025e50bf4a7be5ab5c8addb06d35a9a Mon Sep 17 00:00:00 2001 From: jesszzzz Date: Wed, 8 Jan 2025 10:49:58 -0500 Subject: [PATCH] Fix instantiate side effect modifying input config parent (#3005) * Deepcopy full config in instantiate to avoid unexpected side effects * Add test checking input config is unchanged * Create flag to revert deepcopy behavior in instantiate --- hydra/_internal/instantiate/_instantiate2.py | 45 ++++++++++++++-- news/3001.bugfix | 1 + tests/instantiate/test_instantiate.py | 55 +++++++++++++++++++- 3 files changed, 95 insertions(+), 6 deletions(-) create mode 100644 news/3001.bugfix diff --git a/hydra/_internal/instantiate/_instantiate2.py b/hydra/_internal/instantiate/_instantiate2.py index 2f09ece868b..c5adb7e110d 100644 --- a/hydra/_internal/instantiate/_instantiate2.py +++ b/hydra/_internal/instantiate/_instantiate2.py @@ -145,7 +145,30 @@ def _resolve_target( return target -def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: +def _deep_copy_full_config(subconfig: Any) -> Any: + """Deep copy full config from root to leaf and return the copied subconfig""" + if not OmegaConf.is_config(subconfig): + return copy.deepcopy(subconfig) + + full_key = subconfig._get_full_key(None) + if full_key: + full_config_copy = copy.deepcopy(subconfig._get_root()) + if OmegaConf.is_list(subconfig._get_parent()): + # OmegaConf has a bug where _get_full_key doesn't add [] if the parent + # is a list, eg. instead of foo[0], it'll return foo0 + index = subconfig._key() + full_key = full_key[: -len(str(index))] + f"[{index}]" + return OmegaConf.select(full_config_copy, full_key) + else: + return copy.deepcopy(subconfig) + + +def instantiate( + config: Any, + *args: Any, + _skip_instantiate_full_deepcopy_: bool = False, + **kwargs: Any, +) -> Any: """ :param config: An config object describing what to call and what params to use. In addition to the parameters, the config must contain: @@ -168,6 +191,10 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: are converted to dicts / lists too. _partial_: If True, return functools.partial wrapped method or object False by default. Configure per target. + :param _skip_instantiate_full_deepcopy_: If True, deep copy just the input config instead + of full config before resolving omegaconf interpolations, which may + potentially modify the config's parent/sibling configs in place. + False by default. :param args: Optional positional parameters pass-through :param kwargs: Optional named parameters to override parameters in the config object. Parameters not present @@ -207,11 +234,15 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: if OmegaConf.is_dict(config): # Finalize config (convert targets to strings, merge with kwargs) - config_copy = copy.deepcopy(config) + # Create copy to avoid mutating original + if _skip_instantiate_full_deepcopy_: + config_copy = copy.deepcopy(config) + config_copy._set_parent(config._get_parent()) + else: + config_copy = _deep_copy_full_config(config) config_copy._set_flag( flags=["allow_objects", "struct", "readonly"], values=[True, False, False] ) - config_copy._set_parent(config._get_parent()) config = config_copy if kwargs: @@ -228,11 +259,15 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: ) elif OmegaConf.is_list(config): # Finalize config (convert targets to strings, merge with kwargs) - config_copy = copy.deepcopy(config) + # Create copy to avoid mutating original + if _skip_instantiate_full_deepcopy_: + config_copy = copy.deepcopy(config) + config_copy._set_parent(config._get_parent()) + else: + config_copy = _deep_copy_full_config(config) config_copy._set_flag( flags=["allow_objects", "struct", "readonly"], values=[True, False, False] ) - config_copy._set_parent(config._get_parent()) config = config_copy OmegaConf.resolve(config) diff --git a/news/3001.bugfix b/news/3001.bugfix new file mode 100644 index 00000000000..fa6df0c1e83 --- /dev/null +++ b/news/3001.bugfix @@ -0,0 +1 @@ +Fix unexpected resolution side-effect that caused modifications to the input config parent in `hydra.utils.instantiate` diff --git a/tests/instantiate/test_instantiate.py b/tests/instantiate/test_instantiate.py index 5e0654cecaf..10a67b1c14e 100644 --- a/tests/instantiate/test_instantiate.py +++ b/tests/instantiate/test_instantiate.py @@ -419,8 +419,10 @@ def test_class_instantiate( recursive: bool, ) -> Any: passthrough["_recursive_"] = recursive + original_config_str = str(config) obj = instantiate_func(config, **passthrough) assert partial_equal(obj, expected) + assert str(config) == original_config_str def test_partial_with_missing(instantiate_func: Any) -> Any: @@ -431,10 +433,12 @@ def test_partial_with_missing(instantiate_func: Any) -> Any: "b": 20, "c": 30, } + original_config_str = str(config) partial_obj = instantiate_func(config) assert partial_equal(partial_obj, partial(AClass, b=20, c=30)) obj = partial_obj(a=10) assert partial_equal(obj, AClass(a=10, b=20, c=30)) + assert str(config) == original_config_str def test_instantiate_with_missing(instantiate_func: Any) -> Any: @@ -468,6 +472,7 @@ def test_none_cases( ListConfig(None), ], } + original_config_str = str(cfg) ret = instantiate_func(cfg) assert ret.kwargs["none_dict"] is None assert ret.kwargs["none_list"] is None @@ -477,8 +482,11 @@ def test_none_cases( assert ret.kwargs["list"][0] == 10 assert ret.kwargs["list"][1] is None assert ret.kwargs["list"][2] is None + assert str(cfg) == original_config_str +@mark.parametrize("skip_deepcopy", [True, False]) +@mark.parametrize("convert_to_list", [True, False]) @mark.parametrize( "input_conf, passthrough, expected", [ @@ -537,6 +545,32 @@ def test_none_cases( 6, id="interpolation_from_recursive", ), + param( + { + "my_id": 5, + "node": { + "b": "${foo_b}", + }, + "foo_b": { + "unique_id": "${my_id}", + }, + }, + {}, + OmegaConf.create({"b": {"unique_id": 5}}), + id="interpolation_from_parent_with_interpolation", + ), + param( + { + "my_id": 5, + "node": "${foo_b}", + "foo_b": { + "unique_id": "${my_id}", + }, + }, + {}, + OmegaConf.create({"unique_id": 5}), + id="interpolation_from_parent_with_interpolation", + ), ], ) def test_interpolation_accessing_parent( @@ -544,15 +578,34 @@ def test_interpolation_accessing_parent( input_conf: Any, passthrough: Dict[str, Any], expected: Any, + convert_to_list: bool, + skip_deepcopy: bool, ) -> Any: + if convert_to_list: + input_conf = copy.deepcopy(input_conf) + input_conf["node"] = [input_conf["node"]] cfg_copy = OmegaConf.create(input_conf) input_conf = OmegaConf.create(input_conf) - obj = instantiate_func(input_conf.node, **passthrough) + original_config_str = str(input_conf) + if convert_to_list: + obj = instantiate_func( + input_conf.node[0], + _skip_instantiate_full_deepcopy_=skip_deepcopy, + **passthrough, + ) + else: + obj = instantiate_func( + input_conf.node, + _skip_instantiate_full_deepcopy_=skip_deepcopy, + **passthrough, + ) if isinstance(expected, partial): assert partial_equal(obj, expected) else: assert obj == expected assert input_conf == cfg_copy + if not skip_deepcopy: + assert str(input_conf) == original_config_str @mark.parametrize(