1
1
"""
2
2
Base model class as lowest building block for dynamic modules
3
3
"""
4
+
4
5
import inspect
5
6
from abc import ABC , abstractmethod
6
7
from functools import wraps
17
18
module_validations : Validations = {
18
19
"name" : ["str" , ("longer" , 2 )],
19
20
"print_prio" : [("in" , PRINT_PRIORITY )],
20
- "device" : ["str" , ("or" , (("in" , ["cuda" , "cpu" ]), ("instance " , torch .device )))],
21
+ "device" : [("or" , (("in" , ["cuda" , "cpu" ]), ("type " , torch .device )))],
21
22
"gpus" : ["optional" , lambda gpus : isinstance (gpus , list ) and all (isinstance (gpu , int ) for gpu in gpus )],
22
23
"num_workers" : ["optional" , "int" , ("gte" , 0 )],
23
24
"sp" : [("instance" , bool )],
@@ -97,7 +98,7 @@ class BaseModule(ABC):
97
98
"""
98
99
99
100
@enable_keyboard_interrupt
100
- def __init__ (self , config : Config , path : NodePath , validate_base : bool = False ):
101
+ def __init__ (self , config : Config , path : NodePath ):
101
102
self .config : Config = config
102
103
self .params : Config = get_sub_config (config , path )
103
104
self ._path : NodePath = path
@@ -106,16 +107,13 @@ def __init__(self, config: Config, path: NodePath, validate_base: bool = False):
106
107
if not self .config ["gpus" ]:
107
108
self .config ["gpus" ] = [- 1 ]
108
109
elif isinstance (self .config ["gpus" ], str ):
109
- self .config ["gpus" ] = (
110
- [int (i ) for i in self .config ["gpus" ].split ("," )] if torch .cuda .device_count () >= 1 else [- 1 ]
111
- )
110
+ self .config ["gpus" ] = [int (i ) for i in self .config ["gpus" ].split ("," )]
111
+
112
112
# set default value of num_workers
113
113
if not self .config ["num_workers" ]:
114
114
self .config ["num_workers" ] = 0
115
115
116
- # validate config when calling BaseModule class and flag is True
117
- if validate_base :
118
- self .validate_params (module_validations , "config" )
116
+ self .validate_params (module_validations , "config" )
119
117
120
118
def validate_params (self , validations : Validations , attrib_name : str = "params" ) -> None :
121
119
"""Given per key validations, validate this module's parameters.
@@ -173,7 +171,9 @@ def validate_params(self, validations: Validations, attrib_name: str = "params")
173
171
if param_name not in getattr (self , attrib_name ):
174
172
if "optional" in list_of_validations :
175
173
continue # value is optional and does not exist, skip validation
176
- raise InvalidParameterException (f"{ param_name } is expected to be in module { self .__class__ .__name__ } " )
174
+ raise InvalidParameterException (
175
+ f"'{ param_name } ' is expected to be in module '{ self .__class__ .__name__ } '"
176
+ )
177
177
178
178
# it is now safe to get the value
179
179
value = getattr (self , attrib_name )[param_name ]
@@ -202,13 +202,13 @@ def validate_params(self, validations: Validations, attrib_name: str = "params")
202
202
if validate_value (value = value , data = data , validation = validation_name ):
203
203
continue
204
204
raise InvalidParameterException (
205
- f"In module { self .__class__ .__name__ } , parameter { param_name } is not valid. "
206
- f"Value is { value } and is expected to have validation(s) { list_of_validations } ."
205
+ f"In module ' { self .__class__ .__name__ } ' , parameter ' { param_name } ' is not valid. "
206
+ f"Value is ' { value } ' and is expected to have validation(s) ' { list_of_validations } ' ."
207
207
)
208
208
# no other case was true
209
209
raise ValidationException (
210
- f"Validation is expected to be callable or tuple, but is { type (validation )} . "
211
- f"Current module: { self .__class__ .__name__ } , Parameter: { param_name } "
210
+ f"Validation is expected to be callable or tuple, but is ' { type (validation )} ' . "
211
+ f"Current module: ' { self .__class__ .__name__ } ' , Parameter: ' { param_name } ' "
212
212
)
213
213
214
214
@abstractmethod
0 commit comments