|
16 | 16 | from collections import defaultdict |
17 | 17 | from logging import getLogger |
18 | 18 | from pathlib import Path |
19 | | -from typing import Any, Callable, Mapping, Sequence, Type, overload |
| 19 | +from typing import Any, Callable, Hashable, Mapping, Sequence, Type, overload |
20 | 20 | from typing_extensions import TypeGuard |
21 | 21 | import warnings |
22 | 22 | from simple_parsing.helpers.subgroups import SubgroupKey |
@@ -647,32 +647,18 @@ def _resolve_subgroups( |
647 | 647 | # Sanity checks: |
648 | 648 | if subgroup_field.subgroup_default is dataclasses.MISSING: |
649 | 649 | assert argument_options["required"] |
650 | | - elif isinstance(argument_options["default"], dict): |
651 | | - # TODO: In this case here, the value of a nested subgroup in this default dict |
652 | | - # should also be used! |
653 | | - # BUG #276: The default here is a dict because it came from a config file. |
654 | | - # Here we want the subgroup field to have a 'str' default, because we just want |
655 | | - # to be able to choose between the subgroup names. |
656 | | - _default = argument_options["default"] |
657 | | - _default_key = _infer_subgroup_key_to_use_from_config( |
658 | | - default_in_config=_default, |
659 | | - # subgroup_default=subgroup_field.subgroup_default, |
660 | | - subgroup_choices=subgroup_field.subgroup_choices, |
661 | | - ) |
662 | | - # We'd like this field to (at least temporarily) have a different default |
663 | | - # value that is the subgroup key instead of the dictionary. |
664 | | - argument_options["default"] = _default_key |
665 | | - |
| 650 | + if "default" in argument_options: |
| 651 | + # todo: should ideally not set this in the first place... |
| 652 | + assert argument_options["default"] is dataclasses.MISSING |
| 653 | + argument_options.pop("default") |
| 654 | + assert "default" not in argument_options |
666 | 655 | else: |
667 | | - assert ( |
668 | | - argument_options["default"] is subgroup_field.subgroup_default |
669 | | - ), argument_options["default"] |
670 | | - assert not is_dataclass_instance(argument_options["default"]) |
671 | | - |
672 | | - # TODO: Do we really need to care about this "SUPPRESS" stuff here? |
673 | | - if argparse.SUPPRESS in subgroup_field.parent.defaults: |
674 | | - assert argument_options["default"] is argparse.SUPPRESS |
675 | | - argument_options["default"] = argparse.SUPPRESS |
| 656 | + assert "default" in argument_options |
| 657 | + assert argument_options["default"] == subgroup_field.default |
| 658 | + argument_options["default"] = _adjust_default_value_for_subgroup_field( |
| 659 | + subgroup_field=subgroup_field, |
| 660 | + subgroup_default=argument_options["default"], |
| 661 | + ) |
676 | 662 |
|
677 | 663 | logger.debug( |
678 | 664 | f"Adding subgroup argument: add_argument(*{flags} **{str(argument_options)})" |
@@ -1198,83 +1184,6 @@ def _create_dataclass_instance( |
1198 | 1184 | return constructor(**constructor_args) |
1199 | 1185 |
|
1200 | 1186 |
|
def _infer_subgroup_key_to_use_from_config(
    default_in_config: dict[str, Any],
    subgroup_choices: Mapping[SubgroupKey, type[Dataclass] | functools.partial[Dataclass]],
) -> SubgroupKey:
    """Pick the key of `subgroup_choices` to use for a dict default of a subgroup field.

    The lookup proceeds in this order:

    1. If the dict has a ``SUBGROUP_KEY_FLAG`` entry, its value is returned.
    2. If the dict compares equal to one of the subgroup values, the matching key is returned.
    3. If every subgroup value is a dataclass type, the first key whose type has the qualified
       name (``module.qualname``) stored under ``DC_TYPE_KEY`` is returned.
    4. Otherwise, the key whose default constructor arguments share the most values with the
       dict is returned, and a ``RuntimeWarning`` is emitted.

    Args:
        default_in_config: The dict that was set as the default of the subgroup field.
        subgroup_choices: The mapping from subgroup key to subgroup value for that field.

    Returns:
        The subgroup key to use as the default.

    Raises:
        AssertionError: If the ``DC_TYPE_KEY`` entry is needed but missing, or if no dataclass
            type in `subgroup_choices` matches it.
    """
    if SUBGROUP_KEY_FLAG in default_in_config:
        return default_in_config[SUBGROUP_KEY_FLAG]

    for subgroup_key, subgroup_value in subgroup_choices.items():
        if default_in_config == subgroup_value:
            return subgroup_key

    assert (
        DC_TYPE_KEY in default_in_config
    ), f"FIXME: assuming that the {DC_TYPE_KEY} is in the config dict."
    _default_type_name: str = default_in_config[DC_TYPE_KEY]

    if _has_values_of_type(subgroup_choices, type) and all(
        dataclasses.is_dataclass(subgroup_option) for subgroup_option in subgroup_choices.values()
    ):
        # Simpler case: All the subgroup options are dataclass types. We just get the key that
        # matches the type that was saved in the config dict.
        keys_with_matching_type: list[SubgroupKey] = [
            k
            for k, v in subgroup_choices.items()
            if (isinstance(v, type) and f"{v.__module__}.{v.__qualname__}" == _default_type_name)
        ]
        # NOTE: There could be duplicates, e.g. `subgroups({"a": A, "aa": A})`. The first
        # matching key is used in that case.
        assert len(keys_with_matching_type) >= 1
        return keys_with_matching_type[0]

    # Find the best subgroup key to use, based on the number of matching constructor arguments
    # between the default in the config and the defaults for each subgroup.
    constructor_args_in_each_subgroup = {
        key: _default_constructor_argument_values(subgroup_value)
        for key, subgroup_value in subgroup_choices.items()
    }
    n_matching_values = {
        k: _num_matching_values(default_in_config, constructor_args_in_subgroup_value)
        for k, constructor_args_in_subgroup_value in constructor_args_in_each_subgroup.items()
    }
    # `max` returns the first key with the highest score, which is the same element that a
    # stable descending sort would put first.
    closest_subgroup_key = max(subgroup_choices, key=n_matching_values.__getitem__)
    warnings.warn(
        # TODO: Return the dataclass type instead, and be done with it!
        RuntimeWarning(
            f"TODO: The config file contains a default value for a subgroup that isn't in the "
            f"dict of subgroup options. Because of how subgroups are currently implemented, we "
            f"need to find the key in the subgroup choice dict ({subgroup_choices}) that most "
            f"closely matches the value {default_in_config}."
            f"The current implementation tries to use the dataclass type of this closest match "
            f"to parse the additional values from the command-line. "
            f"{default_in_config}. Consider adding the "
            f"{SUBGROUP_KEY_FLAG}: <key of the subgroup to use>"
        )
    )
    return closest_subgroup_key
1276 | | - |
1277 | | - |
1278 | 1187 | def _has_values_of_type( |
1279 | 1188 | mapping: Mapping[K, Any], value_type: type[V] | tuple[type[V], ...] |
1280 | 1189 | ) -> TypeGuard[Mapping[K, V]]: |
@@ -1330,12 +1239,143 @@ def _default_constructor_argument_values( |
1330 | 1239 | return result |
1331 | 1240 |
|
1332 | 1241 |
|
1333 | | -def _num_matching_values(subgroup_default: dict[str, Any], subgroup_choice: dict[str, Any]) -> int: |
1334 | | - """Returns the number of matching entries in the subgroup dict w/ the default from the |
1335 | | - config.""" |
1336 | | - return sum( |
1337 | | - _num_matching_values(default_v, subgroup_choice[k]) |
1338 | | - if isinstance(subgroup_choice.get(k), dict) and isinstance(default_v, dict) |
1339 | | - else int(subgroup_choice.get(k) == default_v) |
1340 | | - for k, default_v in subgroup_default.items() |
| 1242 | +def _adjust_default_value_for_subgroup_field( |
| 1243 | + subgroup_field: FieldWrapper, subgroup_default: Any |
| 1244 | +) -> str | Hashable: |
| 1245 | + |
| 1246 | + if argparse.SUPPRESS in subgroup_field.parent.defaults: |
| 1247 | + assert subgroup_default is argparse.SUPPRESS |
| 1248 | + assert isinstance(subgroup_default, str) |
| 1249 | + return subgroup_default |
| 1250 | + |
| 1251 | + if isinstance(subgroup_default, dict): |
| 1252 | + default_from_config_file = subgroup_default |
| 1253 | + default_from_dataclass_field = subgroup_field.subgroup_default |
| 1254 | + |
| 1255 | + if SUBGROUP_KEY_FLAG in default_from_config_file: |
| 1256 | + _default_subgroup = default_from_config_file[SUBGROUP_KEY_FLAG] |
| 1257 | + logger.debug(f"Using subgroup key {_default_subgroup} as default (from config file)") |
| 1258 | + return _default_subgroup |
| 1259 | + |
| 1260 | + if DC_TYPE_KEY in default_from_config_file: |
| 1261 | + # The type of dataclass is specified in the config file. |
| 1262 | + # We can use that to figure out which subgroup to use. |
| 1263 | + default_dataclass_type_from_config = default_from_config_file[DC_TYPE_KEY] |
| 1264 | + if isinstance(default_dataclass_type_from_config, str): |
| 1265 | + from simple_parsing.helpers.serialization.serializable import _locate |
| 1266 | + |
| 1267 | + # Try to import the type of dataclass given its import path as a string in the |
| 1268 | + # config file. |
| 1269 | + default_dataclass_type_from_config = _locate(default_dataclass_type_from_config) |
| 1270 | + assert is_dataclass_type(default_dataclass_type_from_config) |
| 1271 | + |
| 1272 | + from simple_parsing.helpers.subgroups import _get_dataclass_type_from_callable |
| 1273 | + |
| 1274 | + subgroup_choices_with_matching_type: dict[ |
| 1275 | + Hashable, Dataclass | Callable[[], Dataclass] |
| 1276 | + ] = { |
| 1277 | + subgroup_key: subgroup_value |
| 1278 | + for subgroup_key, subgroup_value in subgroup_field.subgroup_choices.items() |
| 1279 | + if is_dataclass_type(subgroup_value) |
| 1280 | + and subgroup_value == default_dataclass_type_from_config |
| 1281 | + or is_dataclass_instance(subgroup_value) |
| 1282 | + and type(subgroup_value) == default_dataclass_type_from_config |
| 1283 | + or _get_dataclass_type_from_callable(subgroup_value) |
| 1284 | + == default_dataclass_type_from_config |
| 1285 | + } |
| 1286 | + logger.debug( |
| 1287 | + f"Subgroup choices that match the type in the config file: " |
| 1288 | + f"{subgroup_choices_with_matching_type}" |
| 1289 | + ) |
| 1290 | + |
| 1291 | + # IDEA: Try to find the best subgroup key to use, based on the number of matching |
| 1292 | + # constructor arguments between the default in the config and the defaults for each |
| 1293 | + # subgroup. |
| 1294 | + constructor_args_of_each_subgroup_val = { |
| 1295 | + key: ( |
| 1296 | + dataclasses.asdict(subgroup_value) |
| 1297 | + if is_dataclass_instance(subgroup_value) |
| 1298 | + # (the type should have been narrowed by the is_dataclass_instance typeguard, |
| 1299 | + # but somehow isn't...) |
| 1300 | + else _default_constructor_argument_values(subgroup_value) # type: ignore |
| 1301 | + ) |
| 1302 | + for key, subgroup_value in subgroup_choices_with_matching_type.items() |
| 1303 | + } |
| 1304 | + logger.debug( |
| 1305 | + f"Constructor arguments for each subgroup choice: " |
| 1306 | + f"{constructor_args_of_each_subgroup_val}" |
| 1307 | + ) |
| 1308 | + |
| 1309 | + def _num_overlapping_keys( |
| 1310 | + subgroup_default_in_config: PossiblyNestedDict[str, Any], |
| 1311 | + subgroup_option_from_field: PossiblyNestedDict[str, Any], |
| 1312 | + ) -> int: |
| 1313 | + """Returns the number of matching entries in the subgroup dict w/ the default from |
| 1314 | + the config.""" |
| 1315 | + overlap = 0 |
| 1316 | + for key, value in subgroup_default_in_config.items(): |
| 1317 | + if key in subgroup_option_from_field: |
| 1318 | + overlap += 1 |
| 1319 | + if isinstance(value, dict) and isinstance( |
| 1320 | + subgroup_option_from_field[key], dict |
| 1321 | + ): |
| 1322 | + overlap += _num_overlapping_keys( |
| 1323 | + value, subgroup_option_from_field[key] |
| 1324 | + ) |
| 1325 | + return overlap |
| 1326 | + |
| 1327 | + n_matching_values = { |
| 1328 | + k: _num_overlapping_keys(default_from_config_file, constructor_args_in_value) |
| 1329 | + for k, constructor_args_in_value in constructor_args_of_each_subgroup_val.items() |
| 1330 | + } |
| 1331 | + logger.debug( |
| 1332 | + f"Number of overlapping keys for each subgroup choice: {n_matching_values}" |
| 1333 | + ) |
| 1334 | + closest_subgroups_first = sorted( |
| 1335 | + subgroup_choices_with_matching_type.keys(), |
| 1336 | + key=n_matching_values.__getitem__, |
| 1337 | + reverse=True, |
| 1338 | + ) |
| 1339 | + closest_subgroup_key = closest_subgroups_first[0] |
| 1340 | + |
| 1341 | + warnings.warn( |
| 1342 | + RuntimeWarning( |
| 1343 | + f"The config file contains a default value for a subgroup field that isn't in " |
| 1344 | + f"the dict of subgroup options. " |
| 1345 | + f"Because of how subgroups are currently implemented, we need to find the key " |
| 1346 | + f"in the subgroup choice dict that most closely matches the value " |
| 1347 | + f"{default_from_config_file} in order to populate the default values for " |
| 1348 | + f"other fields.\n" |
| 1349 | + f"The default in the config file: {default_from_config_file}\n" |
| 1350 | + f"The default in the dataclass field: {default_from_dataclass_field}\n" |
| 1351 | + f"The subgroups dict: {subgroup_field.subgroup_choices}\n" |
| 1352 | + f"The current implementation tries to use the dataclass type of this closest " |
| 1353 | + f"match to parse the additional values from the command-line. " |
| 1354 | + f"Consider adding a {SUBGROUP_KEY_FLAG!r}: <key of the subgroup to use> item " |
| 1355 | + f"in the dict entry for that subgroup field in your config, to make it easier " |
| 1356 | + f"to tell directly which subgroup to use." |
| 1357 | + ) |
| 1358 | + ) |
| 1359 | + return closest_subgroup_key |
| 1360 | + |
| 1361 | + logger.debug( |
| 1362 | + f"Using subgroup key {default_from_dataclass_field} as default (from the dataclass " |
| 1363 | + f"field)" |
| 1364 | + ) |
| 1365 | + return default_from_dataclass_field |
| 1366 | + |
| 1367 | + if subgroup_default in subgroup_field.subgroup_choices.keys(): |
| 1368 | + return subgroup_default |
| 1369 | + |
| 1370 | + if subgroup_default in subgroup_field.subgroup_choices.values(): |
| 1371 | + matching_keys = [ |
| 1372 | + k for k, v in subgroup_field.subgroup_choices.items() if v == subgroup_default |
| 1373 | + ] |
| 1374 | + return matching_keys[0] |
| 1375 | + |
| 1376 | + raise RuntimeError( |
| 1377 | + f"Error: Unable to figure out what key matches the default value for the subgroup at " |
| 1378 | + f"{subgroup_field.dest}! (expected to either have the {SUBGROUP_KEY_FLAG!r} flag set, or " |
| 1379 | + f"one of the keys or values of the subgroups dict of that field: " |
| 1380 | + f"{subgroup_field.subgroup_choices})" |
1341 | 1381 | ) |
0 commit comments