Skip to content

Commit f0800c4

Browse files
committed
Refactoring subgroup default
Signed-off-by: Fabrice Normandin <[email protected]>
1 parent f4b68fc commit f0800c4

File tree

3 files changed

+200
-114
lines changed

3 files changed

+200
-114
lines changed

simple_parsing/helpers/subgroups.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,16 @@ def subgroups(
5757
if not isinstance(default, Hashable):
5858
raise ValueError(
5959
"'default' can either be a key of the subgroups dict or a hashable (frozen) "
60-
"dataclass."
60+
"dataclass in the values of the subgroup dict."
6161
)
6262
if default not in subgroups.values():
6363
# NOTE: The reason we enforce this is perhaps artificial, but it's because the way we
6464
# implement subgroups requires us to know the key that is selected in the dict.
65-
raise ValueError(f"Default value {default} needs to be a value in the subgroups dict.")
65+
raise ValueError(
66+
f"When passing a dataclass instance as the `default` for the subgroups, it needs "
67+
f"to be a hashable value (e.g. frozen dataclass) in the subgroups dict. "
68+
f"Got {default}"
69+
)
6670
elif default is not MISSING and default not in subgroups.keys():
6771
raise ValueError("default must be a key in the subgroups dict!")
6872

simple_parsing/parsing.py

Lines changed: 151 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from collections import defaultdict
1717
from logging import getLogger
1818
from pathlib import Path
19-
from typing import Any, Callable, Mapping, Sequence, Type, overload
19+
from typing import Any, Callable, Hashable, Mapping, Sequence, Type, overload
2020
from typing_extensions import TypeGuard
2121
import warnings
2222
from simple_parsing.helpers.subgroups import SubgroupKey
@@ -647,32 +647,18 @@ def _resolve_subgroups(
647647
# Sanity checks:
648648
if subgroup_field.subgroup_default is dataclasses.MISSING:
649649
assert argument_options["required"]
650-
elif isinstance(argument_options["default"], dict):
651-
# TODO: In this case here, the value of a nested subgroup in this default dict
652-
# should also be used!
653-
# BUG #276: The default here is a dict because it came from a config file.
654-
# Here we want the subgroup field to have a 'str' default, because we just want
655-
# to be able to choose between the subgroup names.
656-
_default = argument_options["default"]
657-
_default_key = _infer_subgroup_key_to_use_from_config(
658-
default_in_config=_default,
659-
# subgroup_default=subgroup_field.subgroup_default,
660-
subgroup_choices=subgroup_field.subgroup_choices,
661-
)
662-
# We'd like this field to (at least temporarily) have a different default
663-
# value that is the subgroup key instead of the dictionary.
664-
argument_options["default"] = _default_key
665-
650+
if "default" in argument_options:
651+
# todo: should ideally not set this in the first place...
652+
assert argument_options["default"] is dataclasses.MISSING
653+
argument_options.pop("default")
654+
assert "default" not in argument_options
666655
else:
667-
assert (
668-
argument_options["default"] is subgroup_field.subgroup_default
669-
), argument_options["default"]
670-
assert not is_dataclass_instance(argument_options["default"])
671-
672-
# TODO: Do we really need to care about this "SUPPRESS" stuff here?
673-
if argparse.SUPPRESS in subgroup_field.parent.defaults:
674-
assert argument_options["default"] is argparse.SUPPRESS
675-
argument_options["default"] = argparse.SUPPRESS
656+
assert "default" in argument_options
657+
assert argument_options["default"] == subgroup_field.default
658+
argument_options["default"] = _adjust_default_value_for_subgroup_field(
659+
subgroup_field=subgroup_field,
660+
subgroup_default=argument_options["default"],
661+
)
676662

677663
logger.debug(
678664
f"Adding subgroup argument: add_argument(*{flags} **{str(argument_options)})"
@@ -1198,83 +1184,6 @@ def _create_dataclass_instance(
11981184
return constructor(**constructor_args)
11991185

12001186

1201-
def _infer_subgroup_key_to_use_from_config(
1202-
default_in_config: dict[str, Any],
1203-
# subgroup_default: Hashable,
1204-
subgroup_choices: Mapping[SubgroupKey, type[Dataclass] | functools.partial[Dataclass]],
1205-
) -> SubgroupKey:
1206-
config_default = default_in_config
1207-
1208-
if SUBGROUP_KEY_FLAG in default_in_config:
1209-
return default_in_config[SUBGROUP_KEY_FLAG]
1210-
1211-
for subgroup_key, subgroup_value in subgroup_choices.items():
1212-
if default_in_config == subgroup_value:
1213-
return subgroup_key
1214-
1215-
assert (
1216-
DC_TYPE_KEY in config_default
1217-
), f"FIXME: assuming that the {DC_TYPE_KEY} is in the config dict."
1218-
_default_type_name: str = config_default[DC_TYPE_KEY]
1219-
1220-
if _has_values_of_type(subgroup_choices, type) and all(
1221-
dataclasses.is_dataclass(subgroup_option) for subgroup_option in subgroup_choices.values()
1222-
):
1223-
# Simpler case: All the subgroup options are dataclass types. We just get the key that
1224-
# matches the type that was saved in the config dict.
1225-
subgroup_keys_with_value_matching_config_default_type: list[SubgroupKey] = [
1226-
k
1227-
for k, v in subgroup_choices.items()
1228-
if (isinstance(v, type) and f"{v.__module__}.{v.__qualname__}" == _default_type_name)
1229-
]
1230-
# NOTE: There could be duplicates I guess? Something like `subgroups({"a": A, "aa": A})`
1231-
assert len(subgroup_keys_with_value_matching_config_default_type) >= 1
1232-
return subgroup_keys_with_value_matching_config_default_type[0]
1233-
1234-
# IDEA: Try to find the best subgroup key to use, based on the number of matching constructor
1235-
# arguments between the default in the config and the defaults for each subgroup.
1236-
constructor_args_in_each_subgroup = {
1237-
key: _default_constructor_argument_values(subgroup_value)
1238-
for key, subgroup_value in subgroup_choices.items()
1239-
}
1240-
n_matching_values = {
1241-
k: _num_matching_values(config_default, constructor_args_in_subgroup_value)
1242-
for k, constructor_args_in_subgroup_value in constructor_args_in_each_subgroup.items()
1243-
}
1244-
closest_subgroups_first = sorted(
1245-
subgroup_choices.keys(),
1246-
key=n_matching_values.__getitem__,
1247-
reverse=True,
1248-
)
1249-
warnings.warn(
1250-
# TODO: Return the dataclass type instead, and be done with it!
1251-
RuntimeWarning(
1252-
f"TODO: The config file contains a default value for a subgroup that isn't in the "
1253-
f"dict of subgroup options. Because of how subgroups are currently implemented, we "
1254-
f"need to find the key in the subgroup choice dict ({subgroup_choices}) that most "
1255-
f"closely matches the value {config_default}."
1256-
f"The current implementation tries to use the dataclass type of this closest match "
1257-
f"to parse the additional values from the command-line. "
1258-
f"{default_in_config}. Consider adding the "
1259-
f"{SUBGROUP_KEY_FLAG}: <key of the subgroup to use>"
1260-
)
1261-
)
1262-
return closest_subgroups_first[0]
1263-
return closest_subgroups_first[0]
1264-
1265-
sorted(
1266-
[k for k, v in subgroup_choices.items()],
1267-
key=_num_matching_values,
1268-
reversed=True,
1269-
)
1270-
# _default_values = copy.deepcopy(config_default)
1271-
# _default_values.pop(DC_TYPE_KEY)
1272-
1273-
# default_constructor_args_for_each_subgroup = {
1274-
# k: _default_constructor_argument_values(dc_type) if dataclasses.is_dataclass(dc_type)
1275-
# }
1276-
1277-
12781187
def _has_values_of_type(
12791188
mapping: Mapping[K, Any], value_type: type[V] | tuple[type[V], ...]
12801189
) -> TypeGuard[Mapping[K, V]]:
@@ -1330,12 +1239,143 @@ def _default_constructor_argument_values(
13301239
return result
13311240

13321241

1333-
def _num_matching_values(subgroup_default: dict[str, Any], subgroup_choice: dict[str, Any]) -> int:
1334-
"""Returns the number of matching entries in the subgroup dict w/ the default from the
1335-
config."""
1336-
return sum(
1337-
_num_matching_values(default_v, subgroup_choice[k])
1338-
if isinstance(subgroup_choice.get(k), dict) and isinstance(default_v, dict)
1339-
else int(subgroup_choice.get(k) == default_v)
1340-
for k, default_v in subgroup_default.items()
1242+
def _adjust_default_value_for_subgroup_field(
1243+
subgroup_field: FieldWrapper, subgroup_default: Any
1244+
) -> str | Hashable:
1245+
1246+
if argparse.SUPPRESS in subgroup_field.parent.defaults:
1247+
assert subgroup_default is argparse.SUPPRESS
1248+
assert isinstance(subgroup_default, str)
1249+
return subgroup_default
1250+
1251+
if isinstance(subgroup_default, dict):
1252+
default_from_config_file = subgroup_default
1253+
default_from_dataclass_field = subgroup_field.subgroup_default
1254+
1255+
if SUBGROUP_KEY_FLAG in default_from_config_file:
1256+
_default_subgroup = default_from_config_file[SUBGROUP_KEY_FLAG]
1257+
logger.debug(f"Using subgroup key {_default_subgroup} as default (from config file)")
1258+
return _default_subgroup
1259+
1260+
if DC_TYPE_KEY in default_from_config_file:
1261+
# The type of dataclass is specified in the config file.
1262+
# We can use that to figure out which subgroup to use.
1263+
default_dataclass_type_from_config = default_from_config_file[DC_TYPE_KEY]
1264+
if isinstance(default_dataclass_type_from_config, str):
1265+
from simple_parsing.helpers.serialization.serializable import _locate
1266+
1267+
# Try to import the type of dataclass given its import path as a string in the
1268+
# config file.
1269+
default_dataclass_type_from_config = _locate(default_dataclass_type_from_config)
1270+
assert is_dataclass_type(default_dataclass_type_from_config)
1271+
1272+
from simple_parsing.helpers.subgroups import _get_dataclass_type_from_callable
1273+
1274+
subgroup_choices_with_matching_type: dict[
1275+
Hashable, Dataclass | Callable[[], Dataclass]
1276+
] = {
1277+
subgroup_key: subgroup_value
1278+
for subgroup_key, subgroup_value in subgroup_field.subgroup_choices.items()
1279+
if is_dataclass_type(subgroup_value)
1280+
and subgroup_value == default_dataclass_type_from_config
1281+
or is_dataclass_instance(subgroup_value)
1282+
and type(subgroup_value) == default_dataclass_type_from_config
1283+
or _get_dataclass_type_from_callable(subgroup_value)
1284+
== default_dataclass_type_from_config
1285+
}
1286+
logger.debug(
1287+
f"Subgroup choices that match the type in the config file: "
1288+
f"{subgroup_choices_with_matching_type}"
1289+
)
1290+
1291+
# IDEA: Try to find the best subgroup key to use, based on the number of matching
1292+
# constructor arguments between the default in the config and the defaults for each
1293+
# subgroup.
1294+
constructor_args_of_each_subgroup_val = {
1295+
key: (
1296+
dataclasses.asdict(subgroup_value)
1297+
if is_dataclass_instance(subgroup_value)
1298+
# (the type should have been narrowed by the is_dataclass_instance typeguard,
1299+
# but somehow isn't...)
1300+
else _default_constructor_argument_values(subgroup_value) # type: ignore
1301+
)
1302+
for key, subgroup_value in subgroup_choices_with_matching_type.items()
1303+
}
1304+
logger.debug(
1305+
f"Constructor arguments for each subgroup choice: "
1306+
f"{constructor_args_of_each_subgroup_val}"
1307+
)
1308+
1309+
def _num_overlapping_keys(
1310+
subgroup_default_in_config: PossiblyNestedDict[str, Any],
1311+
subgroup_option_from_field: PossiblyNestedDict[str, Any],
1312+
) -> int:
1313+
"""Returns the number of matching entries in the subgroup dict w/ the default from
1314+
the config."""
1315+
overlap = 0
1316+
for key, value in subgroup_default_in_config.items():
1317+
if key in subgroup_option_from_field:
1318+
overlap += 1
1319+
if isinstance(value, dict) and isinstance(
1320+
subgroup_option_from_field[key], dict
1321+
):
1322+
overlap += _num_overlapping_keys(
1323+
value, subgroup_option_from_field[key]
1324+
)
1325+
return overlap
1326+
1327+
n_matching_values = {
1328+
k: _num_overlapping_keys(default_from_config_file, constructor_args_in_value)
1329+
for k, constructor_args_in_value in constructor_args_of_each_subgroup_val.items()
1330+
}
1331+
logger.debug(
1332+
f"Number of overlapping keys for each subgroup choice: {n_matching_values}"
1333+
)
1334+
closest_subgroups_first = sorted(
1335+
subgroup_choices_with_matching_type.keys(),
1336+
key=n_matching_values.__getitem__,
1337+
reverse=True,
1338+
)
1339+
closest_subgroup_key = closest_subgroups_first[0]
1340+
1341+
warnings.warn(
1342+
RuntimeWarning(
1343+
f"The config file contains a default value for a subgroup field that isn't in "
1344+
f"the dict of subgroup options. "
1345+
f"Because of how subgroups are currently implemented, we need to find the key "
1346+
f"in the subgroup choice dict that most closely matches the value "
1347+
f"{default_from_config_file} in order to populate the default values for "
1348+
f"other fields.\n"
1349+
f"The default in the config file: {default_from_config_file}\n"
1350+
f"The default in the dataclass field: {default_from_dataclass_field}\n"
1351+
f"The subgroups dict: {subgroup_field.subgroup_choices}\n"
1352+
f"The current implementation tries to use the dataclass type of this closest "
1353+
f"match to parse the additional values from the command-line. "
1354+
f"Consider adding a {SUBGROUP_KEY_FLAG!r}: <key of the subgroup to use> item "
1355+
f"in the dict entry for that subgroup field in your config, to make it easier "
1356+
f"to tell directly which subgroup to use."
1357+
)
1358+
)
1359+
return closest_subgroup_key
1360+
1361+
logger.debug(
1362+
f"Using subgroup key {default_from_dataclass_field} as default (from the dataclass "
1363+
f"field)"
1364+
)
1365+
return default_from_dataclass_field
1366+
1367+
if subgroup_default in subgroup_field.subgroup_choices.keys():
1368+
return subgroup_default
1369+
1370+
if subgroup_default in subgroup_field.subgroup_choices.values():
1371+
matching_keys = [
1372+
k for k, v in subgroup_field.subgroup_choices.items() if v == subgroup_default
1373+
]
1374+
return matching_keys[0]
1375+
1376+
raise RuntimeError(
1377+
f"Error: Unable to figure out what key matches the default value for the subgroup at "
1378+
f"{subgroup_field.dest}! (expected to either have the {SUBGROUP_KEY_FLAG!r} flag set, or "
1379+
f"one of the keys or values of the subgroups dict of that field: "
1380+
f"{subgroup_field.subgroup_choices})"
13411381
)

test/test_subgroups.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import dataclasses
44
import functools
55
import inspect
6+
import json
67
import shlex
78
import sys
89
from dataclasses import dataclass, field
@@ -970,6 +971,7 @@ class A1OrA2:
970971
(A1OrA2(a=A1()), "--a=a2", A1OrA2(a=A2())),
971972
(A1OrA2(a=also_a2_default()), "", A1OrA2(a=also_a2_default())),
972973
],
974+
ids=repr,
973975
)
974976
@pytest.mark.parametrize("filetype", [".yaml", ".json", ".pkl"])
975977
def test_parse_with_config_file_with_different_subgroup(
@@ -983,8 +985,8 @@ def test_parse_with_config_file_with_different_subgroup(
983985
# I think I was trying to reproduce the issue from #276
984986

985987
config_path = (tmp_path / "bob").with_suffix(filetype)
986-
987988
save(value_in_config, config_path, save_dc_types=True)
989+
988990
assert parse(A1OrA2, config_path=config_path, args=args) == expected
989991

990992

@@ -1000,3 +1002,43 @@ def test_roundtrip(value: Dataclass):
10001002
https://github.com/lebrice/SimpleParsing/pull/284#issuecomment-1783490388."""
10011003
assert from_dict(type(value), to_dict(value)) == value
10021004
assert to_dict(from_dict(type(value), to_dict(value))) == to_dict(value)
1005+
1006+
1007+
@dataclass
1008+
class AorB:
1009+
a_or_b: A | B = subgroups(
1010+
{"a": A, "b": B, "also_a": functools.partial(A, a=1.23)}, default="a"
1011+
)
1012+
1013+
1014+
def test_saved_with_key_as_default(tmp_path: Path):
1015+
"""Test to try to reproduce
1016+
https://github.com/lebrice/SimpleParsing/pull/284#discussion_r1434421587
1017+
"""
1018+
1019+
config_path = tmp_path / "config.json"
1020+
config_path.write_text(json.dumps({"a_or_b": "b"}))
1021+
1022+
assert parse(AorB, args="") == AorB(a_or_b=A())
1023+
assert parse(AorB, config_path=config_path, args="") == AorB(a_or_b=B())
1024+
assert parse(AorB, config_path=config_path, args="--a_or_b=a") == AorB(a_or_b=A())
1025+
1026+
1027+
def test_saved_with_custom_dict_as_default(tmp_path: Path):
1028+
"""Test when a customized dict is set in the config for a subgroups field.
1029+
1030+
We expect to have a warning but for things to work.
1031+
"""
1032+
1033+
config_path = tmp_path / "config.json"
1034+
config_path.write_text(json.dumps({"a_or_b": {"b": "somefoo"}}))
1035+
assert parse(AorB, args="") == AorB(a_or_b=A())
1036+
1037+
with pytest.raises(TypeError):
1038+
# Default is 'a', so we should get a TypeError because b="somefoo" is passed to `A`.
1039+
assert parse(AorB, config_path=config_path, args="")
1040+
1041+
with pytest.warns(RuntimeWarning):
1042+
assert parse(AorB, config_path=config_path, args="--b=bobo") == AorB(a_or_b=B(b="bobo"))
1043+
1044+
assert parse(AorB, config_path=config_path, args="--a_or_b=a") == AorB(a_or_b=A())

0 commit comments

Comments
 (0)