@@ -3,164 +3,21 @@ API - Distributed Training
(Alpha release - usage might change later)
- Helper sessions and methods to run distributed training.
- Check this `mnist example <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_distributed.py>`_.
+ Helper API to run distributed training.
+ Check these `examples <https://github.com/tensorlayer/tensorlayer/tree/master/examples/distributed_training>`_.
.. automodule:: tensorlayer.distributed
.. autosummary::
-   TaskSpecDef
-   TaskSpec
-   DistributedSession
-   StopAtTimeHook
-   LoadCheckpoint
-
+   Trainer
Distributed training
--------------------
-
- TaskSpecDef
+ Trainer
^^^^^^^^^^^
- .. autofunction:: TaskSpecDef
-
- Create TaskSpecDef from environment variables
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
- .. autofunction:: TaskSpec
-
- Distributed session object
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
- .. autofunction:: DistributedSession
-
- Data sharding
- ^^^^^^^^^^^^^^^^^^^^^^
-
- In some cases we want to shard the data among all the training servers instead of
- using the full dataset in every server. TensorFlow >= 1.4 provides some helper classes
- that support data sharding: `Datasets <https://www.tensorflow.org/programmers_guide/datasets>`_.
-
- It is important that shuffling, or any other non-deterministic operation,
- is done after creating the shards:
-
- .. code-block:: python
-
-     import tensorflow as tf
-     import tensorlayer as tl
-     from tensorflow.contrib.data import TextLineDataset
-     from tensorflow.contrib.data import Dataset
-     from tensorlayer.distributed import TaskSpec
-
-     task_spec = TaskSpec()
-     task_spec.create_server()
-     # build the dataset and shard it before any non-deterministic operation
-     files_dataset = Dataset.list_files(files_pattern)
-     dataset = TextLineDataset(files_dataset)
-     dataset = dataset.map(your_python_map_function, num_threads=4)
-     if task_spec is not None:
-         dataset = dataset.shard(task_spec.num_workers, task_spec.shard_index)
-     dataset = dataset.shuffle(buffer_size)
-     dataset = dataset.batch(batch_size)
-     dataset = dataset.repeat(num_epochs)
-     iterator = dataset.make_one_shot_iterator()
-     next_element = iterator.get_next()
-     with tf.device(task_spec.device_fn()):
-         tensors = create_graph(next_element)
-     with tl.DistributedSession(task_spec=task_spec,
-                                checkpoint_dir='/tmp/ckpt') as session:
-         while not session.should_stop():
-             session.run(tensors)
-
-
- Logging
- ^^^^^^^^^^^^^^^^^^^^^^
-
- We can use ``task_spec`` to log only on the master server:
-
- .. code-block:: python
-
-     while not session.should_stop():
-         should_log = task_spec.is_master() and your_conditions
-         if should_log:
-             results = session.run(tensors_with_log_info)
-             logging.info(...)
-         else:
-             results = session.run(tensors)
-
- Continuous evaluation
- ^^^^^^^^^^^^^^^^^^^^^^
-
- You can use one of the workers to run an evaluation for the saved checkpoints:
-
- .. code-block:: python
-
-     import time
-
-     import tensorflow as tf
-     from tensorflow.python.training import session_run_hook
-     from tensorflow.python.training.monitored_session import SingularMonitoredSession
-     from tensorlayer.distributed import StopAtTimeHook, TaskSpec
-
-     class Evaluator(session_run_hook.SessionRunHook):
-         def __init__(self, checkpoints_path, output_path):
-             self.checkpoints_path = checkpoints_path
-             self.summary_writer = tf.summary.FileWriter(output_path)
-             self.latest_checkpoint = ''
-
-         def after_create_session(self, session, coord):
-             checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
-             # wait until a new checkpoint is available
-             while self.latest_checkpoint == checkpoint:
-                 time.sleep(30)
-                 checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
-             self.saver.restore(session, checkpoint)
-             self.latest_checkpoint = checkpoint
-
-         def end(self, session):
-             super(Evaluator, self).end(session)
-             # save the summaries collected for the restored checkpoint
-             step = int(self.latest_checkpoint.split('-')[-1])
-             self.summary_writer.add_summary(self.summary, step)
-
-         def _create_graph(self):
-             # your code to create the graph with the dataset,
-             # returning the summary tensors to evaluate
-             pass
-
-         def run_evaluation(self):
-             with tf.Graph().as_default():
-                 summary_tensors = self._create_graph()
-                 self.saver = tf.train.Saver(var_list=tf.trainable_variables())
-                 # create_hooks() and max_time_secs come from your own configuration
-                 hooks = self.create_hooks()
-                 hooks.append(self)
-                 if self.max_time_secs and self.max_time_secs > 0:
-                     hooks.append(StopAtTimeHook(self.max_time_secs))
-                 # this evaluation runs indefinitely, until the process is killed
-                 while True:
-                     with SingularMonitoredSession(hooks=hooks) as session:
-                         try:
-                             while not session.should_stop():
-                                 self.summary = session.run(summary_tensors)
-                         except tf.errors.OutOfRangeError:
-                             pass
-                     # end of evaluation
-
-     task_spec = TaskSpec().use_last_worker_as_evaluator()
-     if task_spec.is_evaluator():
-         Evaluator().run_evaluation()
-     else:
-         task_spec.create_server()
-         # run normal training
-
-
-
- Session hooks
- ----------------------
-
- TensorFlow provides some `Session Hooks <https://www.tensorflow.org/api_guides/python/train#Training_Hooks>`_
- to run operations in the sessions. We added more hooks to help with common operations.
-
-
- Stop after maximum time
- ^^^^^^^^^^^^^^^^^^^^^^^
-
- .. autofunction:: StopAtTimeHook
+ .. autofunction:: Trainer
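+
+ A minimal usage sketch, assuming a ``Trainer`` driven by a model-building function, a
+ ``tf.data`` training dataset and an optimizer. The keyword arguments and the
+ ``train_to_end`` method used below (``build_training_func``, ``training_dataset``,
+ ``optimizer_args``, ``batch_size``) are inferred from the linked examples and should be
+ treated as assumptions, not the authoritative signature documented above.
+
+ .. code-block:: python
+
+     import tensorflow as tf
+     import tensorlayer as tl
+
+     # Model-building function: receives the input and target tensors of one batch
+     # and returns the network together with the cost to minimise.
+     def build_network(x, y_):
+         net = tl.layers.InputLayer(x, name='input')
+         net = tl.layers.DenseLayer(net, n_units=800, act=tf.nn.relu, name='relu1')
+         net = tl.layers.DenseLayer(net, n_units=10, act=tf.identity, name='output')
+         cost = tl.cost.cross_entropy(net.outputs, y_, name='cost')
+         return net, cost
+
+     # Each worker draws batches from this dataset.
+     X_train, y_train, _, _, _, _ = tl.files.load_mnist_dataset(shape=(-1, 784))
+     training_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
+
+     # NOTE: the argument names here are assumptions taken from the linked examples.
+     trainer = tl.distributed.Trainer(
+         build_training_func=build_network,
+         training_dataset=training_dataset,
+         optimizer=tf.train.AdamOptimizer,
+         optimizer_args={'learning_rate': 0.001},
+         batch_size=500,
+     )
+
+     # Run the training loop until the dataset is exhausted on every worker.
+     trainer.train_to_end()
+
+ In this sketch the ``Trainer`` is assumed to own the session and the distributed setup,
+ so the user code only defines the model and the input pipeline.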
- Initialize network with checkpoint
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
- .. autofunction:: LoadCheckpoint