Skip to content

Commit

Permalink
Add test checking input config is unchanged
Browse files Browse the repository at this point in the history
  • Loading branch information
jesszzzz committed Jan 6, 2025
1 parent 9dedc40 commit da36e03
Showing 1 changed file with 43 additions and 1 deletion.
44 changes: 43 additions & 1 deletion tests/instantiate/test_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,10 @@ def test_class_instantiate(
recursive: bool,
) -> Any:
passthrough["_recursive_"] = recursive
original_config_str = str(config)
obj = instantiate_func(config, **passthrough)
assert partial_equal(obj, expected)
assert str(config) == original_config_str


def test_partial_with_missing(instantiate_func: Any) -> Any:
Expand All @@ -431,10 +433,12 @@ def test_partial_with_missing(instantiate_func: Any) -> Any:
"b": 20,
"c": 30,
}
original_config_str = str(config)
partial_obj = instantiate_func(config)
assert partial_equal(partial_obj, partial(AClass, b=20, c=30))
obj = partial_obj(a=10)
assert partial_equal(obj, AClass(a=10, b=20, c=30))
assert str(config) == original_config_str


def test_instantiate_with_missing(instantiate_func: Any) -> Any:
Expand Down Expand Up @@ -468,6 +472,7 @@ def test_none_cases(
ListConfig(None),
],
}
original_config_str = str(cfg)
ret = instantiate_func(cfg)
assert ret.kwargs["none_dict"] is None
assert ret.kwargs["none_list"] is None
Expand All @@ -477,8 +482,10 @@ def test_none_cases(
assert ret.kwargs["list"][0] == 10
assert ret.kwargs["list"][1] is None
assert ret.kwargs["list"][2] is None
assert str(cfg) == original_config_str


@mark.parametrize("convert_to_list", [True, False])
@mark.parametrize(
"input_conf, passthrough, expected",
[
Expand Down Expand Up @@ -537,22 +544,57 @@ def test_none_cases(
6,
id="interpolation_from_recursive",
),
param(
{
"my_id": 5,
"node": {
"b": "${foo_b}",
},
"foo_b": {
"unique_id": "${my_id}",
},
},
{},
OmegaConf.create({"b": {"unique_id": 5}}),
id="interpolation_from_parent_with_interpolation",
),
param(
{
"my_id": 5,
"node": "${foo_b}",
"foo_b": {
"unique_id": "${my_id}",
},
},
{},
OmegaConf.create({"unique_id": 5}),
id="interpolation_from_parent_with_interpolation",
),
],
)
def test_interpolation_accessing_parent(
instantiate_func: Any,
input_conf: Any,
passthrough: Dict[str, Any],
expected: Any,
convert_to_list: bool,
) -> Any:
if convert_to_list:
input_conf = copy.deepcopy(input_conf)
input_conf["node"] = [input_conf["node"]]
cfg_copy = OmegaConf.create(input_conf)
input_conf = OmegaConf.create(input_conf)
obj = instantiate_func(input_conf.node, **passthrough)
original_config_str = str(input_conf)
if convert_to_list:
obj = instantiate_func(input_conf.node[0], **passthrough)
else:
obj = instantiate_func(input_conf.node, **passthrough)
if isinstance(expected, partial):
assert partial_equal(obj, expected)
else:
assert obj == expected
assert input_conf == cfg_copy
assert str(input_conf) == original_config_str


@mark.parametrize(
Expand Down

0 comments on commit da36e03

Please sign in to comment.