-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathconfig_template.py
31 lines (21 loc) · 962 Bytes
/
config_template.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from configs.Validator import ObjectValidator
from configs.base_config import GeneralConfig, ConvertConfig, ConfigBase, \
DatasetConfigCoco, NetworkConfigOpenPose, TrainingConfigOpenPose
from configs.config_schema import cfg_schema
# Template placeholders: replace USERNAME (or the whole path) for your machine.
default_path, dataset_dir = (
    "/home/USERNAME/rtpose2d_data/",  # root handed to the network and training sections
    "/home/USERNAME/datasets/COCO",   # location handed to the COCO dataset section
)
class OpenPoseConfig(ConfigBase):
    """OpenPose configuration built from the shared config sections.

    The section objects are created once, when the class body executes, from
    the module-level ``default_path`` and ``dataset_dir``. ``__init__`` then
    applies the per-experiment overrides on top of them.

    NOTE(review): the sections are class attributes, so unless
    ``ConfigBase.__init__`` replaces them with per-instance copies, the
    overrides in ``__init__`` mutate objects shared by every instance (and by
    the class itself). That is harmless while only one instance is created;
    confirm before building several configs with different overrides.
    """

    general = GeneralConfig()
    convert = ConvertConfig()
    network = NetworkConfigOpenPose(default_path)
    train = TrainingConfigOpenPose(default_path)
    dataset = DatasetConfigCoco(dataset_dir)

    def __init__(self):
        super().__init__()

        # Machine-specific path; presumably a saved model checkpoint (.pth).
        # Edit for your setup.
        net = self.network
        net.model_state_file = "/media/disks/beta/models/openpose/itsc18_sim_full_c48.pth"

        # Training overrides.
        training = self.train
        training.batch_size = 10
        training.learning_rate = 0.001
class ConfigValidationError(ValueError, SystemError):
    """Raised when ``cfg`` fails validation against ``cfg_schema``.

    ``SystemError`` is documented as an interpreter-internal error, so it
    misdescribes a bad configuration value; ``ValueError`` is the accurate
    base. ``SystemError`` is kept as a second base so existing
    ``except SystemError`` handlers continue to catch this failure.
    """


# Build the configuration and validate it as soon as the module is imported.
# If validation fails, the import itself raises, with the validator's error
# report as the message.
cfg = OpenPoseConfig()
cfg_validator = ObjectValidator(cfg_schema)
if not cfg_validator.validate_object(cfg):
    raise ConfigValidationError(str(cfg_validator.errors))