77import  argparse 
88import  dataclasses 
99import  functools 
10+ import  inspect 
1011import  itertools 
1112import  shlex 
1213import  sys 
1516from  collections  import  defaultdict 
1617from  logging  import  getLogger 
1718from  pathlib  import  Path 
18- from  typing  import  Any , Callable , Sequence , Type , overload 
19- 
19+ from  typing  import  Any , Callable , Mapping , Sequence , Type , overload 
20+ from  typing_extensions  import  TypeGuard 
21+ import  warnings 
2022from  simple_parsing .helpers .subgroups  import  SubgroupKey 
23+ from  simple_parsing .replace  import  SUBGROUP_KEY_FLAG 
2124from  simple_parsing .wrappers .dataclass_wrapper  import  DataclassWrapperType 
2225
2326from  . import  utils 
2427from  .conflicts  import  ConflictResolution , ConflictResolver 
2528from  .help_formatter  import  SimpleHelpFormatter 
26- from  .helpers .serialization .serializable  import  read_file 
29+ from  .helpers .serialization .serializable  import  DC_TYPE_KEY ,  read_file 
2730from  .utils  import  (
31+     K ,
32+     V ,
2833    Dataclass ,
2934    DataclassT ,
35+     PossiblyNestedDict ,
3036    dict_union ,
3137    is_dataclass_instance ,
3238    is_dataclass_type ,
@@ -593,7 +599,7 @@ def _resolve_subgroups(
593599
594600        This modifies the wrappers in-place, by possibly adding children to the wrappers in the 
595601        list. 
596-         Returns a list with the modified wrappers. 
602+         Returns a list with the (now  modified)  wrappers. 
597603
598604        Each round does the following: 
599605        1.  Resolve any conflicts using the conflict resolver. Two subgroups at the same nesting 
@@ -618,13 +624,7 @@ def _resolve_subgroups(
618624        # times. 
619625        subgroup_choice_parser  =  argparse .ArgumentParser (
620626            add_help = False ,
621-             # conflict_resolution=self.conflict_resolution, 
622-             # add_option_string_dash_variants=self.add_option_string_dash_variants, 
623-             # argument_generation_mode=self.argument_generation_mode, 
624-             # nested_mode=self.nested_mode, 
625627            formatter_class = self .formatter_class ,
626-             # add_config_path_arg=self.add_config_path_arg, 
627-             # config_path=self.config_path, 
628628            # NOTE: We disallow abbreviations for subgroups for now. This prevents potential issues 
629629            # for example if you have —a_or_b and A has a field —a then it will error out if you 
630630            # pass —a=1 because 1 isn’t a choice for the a_or_b argument (because --a matches it 
@@ -644,10 +644,27 @@ def _resolve_subgroups(
644644                flags  =  subgroup_field .option_strings 
645645                argument_options  =  subgroup_field .arg_options 
646646
647+                 # Sanity checks: 
647648                if  subgroup_field .subgroup_default  is  dataclasses .MISSING :
648649                    assert  argument_options ["required" ]
650+                 elif  isinstance (argument_options ["default" ], dict ):
651+                     # BUG #276: The default here is a dict because it came from a config file. 
652+                     # Here we want the subgroup field to have a 'str' default, because we just want 
653+                     # to be able to choose between the subgroup names. 
654+                     _default  =  argument_options ["default" ]
655+                     _default_key  =  _infer_subgroup_key_to_use_from_config (
656+                         default_in_config = _default ,
657+                         # subgroup_default=subgroup_field.subgroup_default, 
658+                         subgroup_choices = subgroup_field .subgroup_choices ,
659+                     )
660+                     # We'd like this field to (at least temporarily) have a different default 
661+                     # value that is the subgroup key instead of the dictionary. 
662+                     argument_options ["default" ] =  _default_key 
663+ 
649664                else :
650-                     assert  argument_options ["default" ] is  subgroup_field .subgroup_default 
665+                     assert  (
666+                         argument_options ["default" ] is  subgroup_field .subgroup_default 
667+                     ), argument_options ["default" ]
651668                    assert  not  is_dataclass_instance (argument_options ["default" ])
652669
653670                # TODO: Do we really need to care about this "SUPPRESS" stuff here? 
@@ -1177,3 +1194,146 @@ def _create_dataclass_instance(
11771194            return  None 
11781195    logger .debug (f"Calling constructor: { constructor } { constructor_args }  )
11791196    return  constructor (** constructor_args )
1197+ 
1198+ 
1199+ def  _has_values_of_type (
1200+     mapping : Mapping [K , Any ], value_type : type [V ] |  tuple [type [V ], ...]
1201+ ) ->  TypeGuard [Mapping [K , V ]]:
1202+     # Utility functions used to narrow the type of dictionaries. 
1203+     return  all (isinstance (v , value_type ) for  v  in  mapping .values ())
1204+ 
1205+ 
1206+ def  _has_keys_of_type (
1207+     mapping : Mapping [Any , V ], key_type : type [K ] |  tuple [type [K ], ...]
1208+ ) ->  TypeGuard [Mapping [K , V ]]:
1209+     # Utility functions used to narrow the type of dictionaries. 
1210+     return  all (isinstance (k , key_type ) for  k  in  mapping .keys ())
1211+ 
1212+ 
1213+ def  _has_items_of_type (
1214+     mapping : Mapping [Any , Any ],
1215+     item_type : tuple [type [K ] |  tuple [type [K ], ...], type [V ] |  tuple [type [V ], ...]],
1216+ ) ->  TypeGuard [Mapping [K , V ]]:
1217+     # Utility functions used to narrow the type of a dictionary or mapping. 
1218+     key_type , value_type  =  item_type 
1219+     return  _has_keys_of_type (mapping , key_type ) and  _has_values_of_type (mapping , value_type )
1220+ 
1221+ 
1222+ def  _infer_subgroup_key_to_use_from_config (
1223+     default_in_config : dict [str , Any ],
1224+     # subgroup_default: Hashable, 
1225+     subgroup_choices : Mapping [SubgroupKey , type [Dataclass ] |  functools .partial [Dataclass ]],
1226+ ) ->  SubgroupKey :
1227+     config_default  =  default_in_config 
1228+ 
1229+     if  SUBGROUP_KEY_FLAG  in  default_in_config :
1230+         return  default_in_config [SUBGROUP_KEY_FLAG ]
1231+ 
1232+     for  subgroup_key , subgroup_value  in  subgroup_choices .items ():
1233+         if  default_in_config  ==  subgroup_value :
1234+             return  subgroup_key 
1235+ 
1236+     assert  (
1237+         DC_TYPE_KEY  in  config_default 
1238+     ), f"FIXME: assuming that the { DC_TYPE_KEY }  
1239+     _default_type_name : str  =  config_default [DC_TYPE_KEY ]
1240+ 
1241+     if  _has_values_of_type (subgroup_choices , type ) and  all (
1242+         dataclasses .is_dataclass (subgroup_option ) for  subgroup_option  in  subgroup_choices .values ()
1243+     ):
1244+         # Simpler case: All the subgroup options are dataclass types. We just get the key that 
1245+         # matches the type that was saved in the config dict. 
1246+         subgroup_keys_with_value_matching_config_default_type : list [SubgroupKey ] =  [
1247+             k 
1248+             for  k , v  in  subgroup_choices .items ()
1249+             if  (isinstance (v , type ) and  f"{ v .__module__ } { v .__qualname__ }   ==  _default_type_name )
1250+         ]
1251+         # NOTE: There could be duplicates I guess? Something like `subgroups({"a": A, "aa": A})` 
1252+         assert  len (subgroup_keys_with_value_matching_config_default_type ) >=  1 
1253+         return  subgroup_keys_with_value_matching_config_default_type [0 ]
1254+ 
1255+     # IDEA: Try to find the best subgroup key to use, based on the number of matching constructor 
1256+     # arguments between the default in the config and the defaults for each subgroup. 
1257+     constructor_args_in_each_subgroup  =  {
1258+         key : _default_constructor_argument_values (subgroup_value )
1259+         for  key , subgroup_value  in  subgroup_choices .items ()
1260+     }
1261+     n_matching_values  =  {
1262+         k : _num_matching_values (config_default , constructor_args_in_subgroup_value )
1263+         for  k , constructor_args_in_subgroup_value  in  constructor_args_in_each_subgroup .items ()
1264+     }
1265+     closest_subgroups_first  =  sorted (
1266+         subgroup_choices .keys (),
1267+         key = n_matching_values .__getitem__ ,
1268+         reverse = True ,
1269+     )
1270+     warnings .warn (
1271+         # TODO: Return the dataclass type instead, and be done with it! 
1272+         RuntimeWarning (
1273+             f"TODO: The config file contains a default value for a subgroup that isn't in the " 
1274+             f"dict of subgroup options. Because of how subgroups are currently implemented, we " 
1275+             f"need to find the key in the subgroup choice dict ({ subgroup_choices }  
1276+             f"closely matches the value { config_default }  
1277+             f"The current implementation tries to use the dataclass type of this closest match " 
1278+             f"to parse the additional values from the command-line. " 
1279+             f"{ default_in_config }  
1280+             f"{ SUBGROUP_KEY_FLAG }  
1281+         )
1282+     )
1283+     return  closest_subgroups_first [0 ]
1284+     return  closest_subgroups_first [0 ]
1285+ 
1286+     sorted (
1287+         [k  for  k , v  in  subgroup_choices .items ()],
1288+         key = _num_matching_values ,
1289+         reversed = True ,
1290+     )
1291+     # _default_values = copy.deepcopy(config_default) 
1292+     # _default_values.pop(DC_TYPE_KEY) 
1293+ 
1294+     # default_constructor_args_for_each_subgroup = { 
1295+     #     k: _default_constructor_argument_values(dc_type) if dataclasses.is_dataclass(dc_type) 
1296+     # } 
1297+ 
1298+ 
1299+ def  _default_constructor_argument_values (
1300+     some_dataclass_type : type [Dataclass ] |  functools .partial [Dataclass ],
1301+ ) ->  PossiblyNestedDict [str , Any ]:
1302+     result  =  {}
1303+     if  isinstance (some_dataclass_type , functools .partial ) and  is_dataclass_type (
1304+         some_dataclass_type .func 
1305+     ):
1306+         constructor_arguments_from_classdef  =  _default_constructor_argument_values (
1307+             some_dataclass_type .func 
1308+         )
1309+         # TODO: will probably raise an error! 
1310+         constructor_arguments_from_partial  =  (
1311+             inspect .signature (some_dataclass_type .func )
1312+             .bind_partial (* some_dataclass_type .args , ** some_dataclass_type .keywords )
1313+             .arguments 
1314+         )
1315+         constructor_arguments_from_classdef .update (constructor_arguments_from_partial )
1316+         return  constructor_arguments_from_classdef 
1317+ 
1318+     assert  is_dataclass_type (some_dataclass_type )
1319+     for  field  in  dataclasses .fields (some_dataclass_type ):
1320+         key  =  field .name 
1321+         if  field .default  is  not dataclasses .MISSING :
1322+             result [key ] =  field .default 
1323+         elif  is_dataclass_type (field .type ) or  (
1324+             isinstance (field .default_factory , functools .partial )
1325+             and  dataclasses .is_dataclass (field .default_factory .func )
1326+         ):
1327+             result [key ] =  _default_constructor_argument_values (field .type )
1328+     return  result 
1329+ 
1330+ 
1331+ def  _num_matching_values (subgroup_default : dict [str , Any ], subgroup_choice : dict [str , Any ]) ->  int :
1332+     """Returns the number of matching entries in the subgroup dict w/ the default from the 
1333+     config.""" 
1334+     return  sum (
1335+         _num_matching_values (default_v , subgroup_choice [k ])
1336+         if  isinstance (subgroup_choice .get (k ), dict ) and  isinstance (default_v , dict )
1337+         else  int (subgroup_choice .get (k ) ==  default_v )
1338+         for  k , default_v  in  subgroup_default .items ()
1339+     )
0 commit comments