Skip to content

Commit

Permalink
Fix instantiate side effect modifying input config parent (#3005)
Browse files Browse the repository at this point in the history
* Deepcopy full config in instantiate to avoid unexpected side effects

* Add test checking input config is unchanged

* Create flag to revert deepcopy behavior in instantiate
  • Loading branch information
jesszzzz authored Jan 8, 2025
1 parent 0ede840 commit ca4d25c
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 6 deletions.
45 changes: 40 additions & 5 deletions hydra/_internal/instantiate/_instantiate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,30 @@ def _resolve_target(
return target


def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
def _deep_copy_full_config(subconfig: Any) -> Any:
    """Deep copy the full config from root to leaf and return the copied subconfig.

    The whole tree is copied (not only ``subconfig``) so that the returned node
    hangs off a copied parent chain; in-place work done on it afterwards, such
    as resolving interpolations, then touches the copy instead of the caller's
    config.

    :param subconfig: Node to copy. Anything that is not an OmegaConf config is
        simply deep copied.
    :return: The node matching ``subconfig`` inside a deep copy of its root
        config, or a stand-alone deep copy when ``subconfig`` has no key path
        (for example when it is itself the root).
    """
    if not OmegaConf.is_config(subconfig):
        return copy.deepcopy(subconfig)

    full_key = subconfig._get_full_key(None)

    index = subconfig._key()
    if index is not None and OmegaConf.is_list(subconfig._get_parent()):
        # OmegaConf has a bug where _get_full_key doesn't add [] if the parent
        # is a list, eg. instead of foo[0], it'll return foo0. Strip the bare
        # trailing index and re-append it in bracket form.
        #
        # NOTE(review): when the list is the root, _get_full_key(None) appears
        # to return the raw integer index rather than a string (confirm against
        # the installed omegaconf). Going through str() keeps the slice from
        # failing on an int, and doing this fix-up *before* the emptiness check
        # below keeps index 0 from being mistaken for "no key".
        prefix = str(full_key)[: -len(str(index))]
        full_key = f"{prefix}[{index}]"

    if not full_key:
        # No key path from the root: copy the node on its own.
        return copy.deepcopy(subconfig)

    full_config_copy = copy.deepcopy(subconfig._get_root())
    return OmegaConf.select(full_config_copy, full_key)


def instantiate(
config: Any,
*args: Any,
_skip_instantiate_full_deepcopy_: bool = False,
**kwargs: Any,
) -> Any:
"""
:param config: An config object describing what to call and what params to use.
In addition to the parameters, the config must contain:
Expand All @@ -168,6 +191,10 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
are converted to dicts / lists too.
_partial_: If True, return functools.partial wrapped method or object
False by default. Configure per target.
:param _skip_instantiate_full_deepcopy_: If True, deep copy just the input config instead
of full config before resolving omegaconf interpolations, which may
potentially modify the config's parent/sibling configs in place.
False by default.
:param args: Optional positional parameters pass-through
:param kwargs: Optional named parameters to override
parameters in the config object. Parameters not present
Expand Down Expand Up @@ -207,11 +234,15 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:

if OmegaConf.is_dict(config):
# Finalize config (convert targets to strings, merge with kwargs)
config_copy = copy.deepcopy(config)
# Create copy to avoid mutating original
if _skip_instantiate_full_deepcopy_:
config_copy = copy.deepcopy(config)
config_copy._set_parent(config._get_parent())
else:
config_copy = _deep_copy_full_config(config)
config_copy._set_flag(
flags=["allow_objects", "struct", "readonly"], values=[True, False, False]
)
config_copy._set_parent(config._get_parent())
config = config_copy

if kwargs:
Expand All @@ -228,11 +259,15 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
)
elif OmegaConf.is_list(config):
# Finalize config (convert targets to strings, merge with kwargs)
config_copy = copy.deepcopy(config)
# Create copy to avoid mutating original
if _skip_instantiate_full_deepcopy_:
config_copy = copy.deepcopy(config)
config_copy._set_parent(config._get_parent())
else:
config_copy = _deep_copy_full_config(config)
config_copy._set_flag(
flags=["allow_objects", "struct", "readonly"], values=[True, False, False]
)
config_copy._set_parent(config._get_parent())
config = config_copy

OmegaConf.resolve(config)
Expand Down
1 change: 1 addition & 0 deletions news/3001.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix unexpected resolution side-effect that caused modifications to the input config parent in `hydra.utils.instantiate`
55 changes: 54 additions & 1 deletion tests/instantiate/test_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,10 @@ def test_class_instantiate(
recursive: bool,
) -> Any:
passthrough["_recursive_"] = recursive
original_config_str = str(config)
obj = instantiate_func(config, **passthrough)
assert partial_equal(obj, expected)
assert str(config) == original_config_str


def test_partial_with_missing(instantiate_func: Any) -> Any:
Expand All @@ -431,10 +433,12 @@ def test_partial_with_missing(instantiate_func: Any) -> Any:
"b": 20,
"c": 30,
}
original_config_str = str(config)
partial_obj = instantiate_func(config)
assert partial_equal(partial_obj, partial(AClass, b=20, c=30))
obj = partial_obj(a=10)
assert partial_equal(obj, AClass(a=10, b=20, c=30))
assert str(config) == original_config_str


def test_instantiate_with_missing(instantiate_func: Any) -> Any:
Expand Down Expand Up @@ -468,6 +472,7 @@ def test_none_cases(
ListConfig(None),
],
}
original_config_str = str(cfg)
ret = instantiate_func(cfg)
assert ret.kwargs["none_dict"] is None
assert ret.kwargs["none_list"] is None
Expand All @@ -477,8 +482,11 @@ def test_none_cases(
assert ret.kwargs["list"][0] == 10
assert ret.kwargs["list"][1] is None
assert ret.kwargs["list"][2] is None
assert str(cfg) == original_config_str


@mark.parametrize("skip_deepcopy", [True, False])
@mark.parametrize("convert_to_list", [True, False])
@mark.parametrize(
"input_conf, passthrough, expected",
[
Expand Down Expand Up @@ -537,22 +545,67 @@ def test_none_cases(
6,
id="interpolation_from_recursive",
),
param(
{
"my_id": 5,
"node": {
"b": "${foo_b}",
},
"foo_b": {
"unique_id": "${my_id}",
},
},
{},
OmegaConf.create({"b": {"unique_id": 5}}),
id="interpolation_from_parent_with_interpolation",
),
param(
{
"my_id": 5,
"node": "${foo_b}",
"foo_b": {
"unique_id": "${my_id}",
},
},
{},
OmegaConf.create({"unique_id": 5}),
id="interpolation_from_parent_with_interpolation",
),
],
)
def test_interpolation_accessing_parent(
instantiate_func: Any,
input_conf: Any,
passthrough: Dict[str, Any],
expected: Any,
convert_to_list: bool,
skip_deepcopy: bool,
) -> Any:
if convert_to_list:
input_conf = copy.deepcopy(input_conf)
input_conf["node"] = [input_conf["node"]]
cfg_copy = OmegaConf.create(input_conf)
input_conf = OmegaConf.create(input_conf)
obj = instantiate_func(input_conf.node, **passthrough)
original_config_str = str(input_conf)
if convert_to_list:
obj = instantiate_func(
input_conf.node[0],
_skip_instantiate_full_deepcopy_=skip_deepcopy,
**passthrough,
)
else:
obj = instantiate_func(
input_conf.node,
_skip_instantiate_full_deepcopy_=skip_deepcopy,
**passthrough,
)
if isinstance(expected, partial):
assert partial_equal(obj, expected)
else:
assert obj == expected
assert input_conf == cfg_copy
if not skip_deepcopy:
assert str(input_conf) == original_config_str


@mark.parametrize(
Expand Down

0 comments on commit ca4d25c

Please sign in to comment.