forked from dhlab-epfl/dhSegment
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
123 lines (107 loc) · 5.83 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import tensorflow as tf
# Tensorflow logging level
from logging import WARNING # import DEBUG, INFO, ERROR for more/less verbosity
tf.logging.set_verbosity(WARNING)
from dh_segment import estimator_fn, utils
from dh_segment.io import input
import json
from glob import glob
import numpy as np
try:
import better_exceptions
except ImportError:
print('/!\ W -- Not able to import package better_exceptions')
pass
from tqdm import trange
from sacred import Experiment
import pandas as pd
ex = Experiment('dhSegment_experiment')
@ex.config
def default_config():
train_data = None # Directory with training data
eval_data = None # Directory with validation data
model_output_dir = None # Directory to output tf model
restore_model = False # Set to true to continue training
classes_file = None # txt file with classes values (unused for REGRESSION)
gpu = '' # GPU to be used for training
prediction_type = utils.PredictionType.CLASSIFICATION # One of CLASSIFICATION, REGRESSION or MULTILABEL
pretrained_model_name = 'resnet50'
model_params = utils.ModelParams(pretrained_model_name=pretrained_model_name).to_dict() # Model parameters
training_params = utils.TrainingParams().to_dict() # Training parameters
if prediction_type == utils.PredictionType.CLASSIFICATION:
assert classes_file is not None
model_params['n_classes'] = utils.get_n_classes_from_file(classes_file)
elif prediction_type == utils.PredictionType.REGRESSION:
model_params['n_classes'] = 1
elif prediction_type == utils.PredictionType.MULTILABEL:
assert classes_file is not None
model_params['n_classes'] = utils.get_n_classes_from_file_multilabel(classes_file)
@ex.automain
def run(train_data, eval_data, model_output_dir, gpu, training_params, _config):
# Create output directory
if not os.path.isdir(model_output_dir):
os.makedirs(model_output_dir)
else:
assert _config.get('restore_model'), \
'{0} already exists, you cannot use it as output directory. ' \
'Set "restore_model=True" to continue training, or delete dir "rm -r {0}"'.format(model_output_dir)
# Save config
with open(os.path.join(model_output_dir, 'config.json'), 'w') as f:
json.dump(_config, f, indent=4, sort_keys=True)
# Create export directory for saved models
saved_model_dir = os.path.join(model_output_dir, 'export')
if not os.path.isdir(saved_model_dir):
os.makedirs(saved_model_dir)
training_params = utils.TrainingParams.from_dict(training_params)
session_config = tf.ConfigProto()
session_config.gpu_options.visible_device_list = str(gpu)
session_config.gpu_options.per_process_gpu_memory_fraction = 0.9
estimator_config = tf.estimator.RunConfig().replace(session_config=session_config,
save_summary_steps=10,
keep_checkpoint_max=1)
estimator = tf.estimator.Estimator(estimator_fn.model_fn, model_dir=model_output_dir,
params=_config, config=estimator_config)
def get_dirs_or_files(input_data):
if os.path.isdir(input_data):
train_input, train_labels_input = os.path.join(input_data, 'images'), os.path.join(input_data, 'labels')
# Check if training dir exists
if not os.path.isdir(train_input):
raise FileNotFoundError(train_input)
if not os.path.isdir(train_labels_input):
raise FileNotFoundError(train_labels_input)
elif os.path.isfile(train_data) and train_data.endswith('.csv'):
train_input = train_data
train_labels_input = None
else:
raise TypeError('input_data {} is neither a directory nor a csv file'.format(input_data))
return train_input, train_labels_input
train_input, train_labels_input = get_dirs_or_files(train_data)
if eval_data is not None:
eval_input, eval_labels_input = get_dirs_or_files(eval_data)
# Configure exporter
serving_input_fn = input.serving_input_filename(training_params.input_resized_size)
exporter = tf.estimator.BestExporter(serving_input_receiver_fn=serving_input_fn, exports_to_keep=2)
for i in trange(0, training_params.n_epochs, training_params.evaluate_every_epoch, desc='Evaluated epochs'):
estimator.train(input.input_fn(train_input,
input_label_dir=train_labels_input,
num_epochs=training_params.evaluate_every_epoch,
batch_size=training_params.batch_size,
data_augmentation=training_params.data_augmentation,
make_patches=training_params.make_patches,
image_summaries=True,
params=_config,
num_threads=32))
if eval_data is not None:
eval_result = estimator.evaluate(input.input_fn(eval_input,
input_label_dir=eval_labels_input,
batch_size=1,
data_augmentation=False,
make_patches=False,
image_summaries=False,
params=_config,
num_threads=32))
else:
eval_result = None
exporter.export(estimator, saved_model_dir, checkpoint_path=None, eval_result=eval_result,
is_the_final_export=False)