
Commit ec33758

[pls merge] docs and release (#814)
* docs and release
* Update package_info.py
* Update layers.rst
1 parent 88aa39b commit ec33758

File tree: 3 files changed (+9, -153 lines)


docs/modules/distributed.rst

Lines changed: 5 additions & 148 deletions
@@ -3,164 +3,21 @@ API - Distributed Training

 (Alpha release - usage might change later)

-Helper sessions and methods to run a distributed training.
-Check this `minst example <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_distributed.py>`_.
+Helper API to run a distributed training.
+Check these `examples <https://github.com/tensorlayer/tensorlayer/tree/master/examples/distributed_training>`_.

 .. automodule:: tensorlayer.distributed

 .. autosummary::

-  TaskSpecDef
-  TaskSpec
-  DistributedSession
-  StopAtTimeHook
-  LoadCheckpoint
-
+  Trainer

 Distributed training
 --------------------

-
-TaskSpecDef
+Trainer
 ^^^^^^^^^^^

-.. autofunction:: TaskSpecDef
-
-Create TaskSpecDef from environment variables
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-.. autofunction:: TaskSpec
-
-Distributed session object
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-.. autofunction:: DistributedSession
-
-Data sharding
-^^^^^^^^^^^^^^^^^^^^^^
-
-In some cases we want to shard the data among all the training servers and
-not use all the data in all servers. TensorFlow >=1.4 provides some helper classes
-to work with data that support data sharding: `Datasets <https://www.tensorflow.org/programmers_guide/datasets>`_
-
-It is important in sharding that the shuffle or any non deterministic operation
-is done after creating the shards:
-
-.. code-block:: python
-
-  from tensorflow.contrib.data import TextLineDataset
-  from tensorflow.contrib.data import Dataset
-
-  task_spec = TaskSpec()
-  task_spec.create_server()
-  files_dataset = Dataset.list_files(files_pattern)
-  dataset = TextLineDataset(files_dataset)
-  dataset = dataset.map(your_python_map_function, num_threads=4)
-  if task_spec is not None:
-      dataset = dataset.shard(task_spec.num_workers, task_spec.shard_index)
-  dataset = dataset.shuffle(buffer_size)
-  dataset = dataset.batch(batch_size)
-  dataset = dataset.repeat(num_epochs)
-  iterator = dataset.make_one_shot_iterator()
-  next_element = iterator.get_next()
-  with tf.device(task_spec.device_fn()):
-      tensors = create_graph(next_element)
-  with tl.DistributedSession(task_spec=task_spec,
-                             checkpoint_dir='/tmp/ckpt') as session:
-      while not session.should_stop():
-          session.run(tensors)
-
-
-Logging
-^^^^^^^^^^^^^^^^^^^^^^
-
-We can use task_spec to log only in the master server:
-
-.. code-block:: python
-
-  while not session.should_stop():
-      should_log = task_spec.is_master() and your_conditions
-      if should_log:
-          results = session.run(tensors_with_log_info)
-          logging.info(...)
-      else:
-          results = session.run(tensors)
-
-Continuous evaluation
-^^^^^^^^^^^^^^^^^^^^^^
-
-You can use one of the workers to run an evaluation for the saved checkpoints:
-
-.. code-block:: python
-
-  import tensorflow as tf
-  from tensorflow.python.training import session_run_hook
-  from tensorflow.python.training.monitored_session import SingularMonitoredSession
-
-  class Evaluator(session_run_hook.SessionRunHook):
-      def __init__(self, checkpoints_path, output_path):
-          self.checkpoints_path = checkpoints_path
-          self.summary_writer = tf.summary.FileWriter(output_path)
-          self.lastest_checkpoint = ''
-
-      def after_create_session(self, session, coord):
-          checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
-          # wait until a new check point is available
-          while self.lastest_checkpoint == checkpoint:
-              time.sleep(30)
-              checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
-          self.saver.restore(session, checkpoint)
-          self.lastest_checkpoint = checkpoint
-
-      def end(self, session):
-          super(Evaluator, self).end(session)
-          # save summaries
-          step = int(self.lastest_checkpoint.split('-')[-1])
-          self.summary_writer.add_summary(self.summary, step)
-
-  def _create_graph():
-      # your code to create the graph with the dataset
-
-  def run_evaluation():
-      with tf.Graph().as_default():
-          summary_tensors = create_graph()
-          self.saver = tf.train.Saver(var_list=tf_variables.trainable_variables())
-          hooks = self.create_hooks()
-          hooks.append(self)
-          if self.max_time_secs and self.max_time_secs > 0:
-              hooks.append(StopAtTimeHook(self.max_time_secs))
-          # this evaluation runs indefinitely, until the process is killed
-          while True:
-              with SingularMonitoredSession(hooks=[self]) as session:
-                  try:
-                      while not sess.should_stop():
-                          self.summary = session.run(summary_tensors)
-                  except OutOfRangeError:
-                      pass
-                  # end of evaluation
-
-  task_spec = TaskSpec().user_last_worker_as_evaluator()
-  if task_spec.is_evaluator():
-      Evaluator().run_evaluation()
-  else:
-      task_spec.create_server()
-      # run normal training
-
-
-
-Session hooks
-----------------------
-
-TensorFlow provides some `Session Hooks <https://www.tensorflow.org/api_guides/python/train#Training_Hooks>`_
-to do some operations in the sessions. We added more to help with common operations.
-
-
-Stop after maximum time
-^^^^^^^^^^^^^^^^^^^^^^^
-
-.. autofunction:: StopAtTimeHook
+.. autofunction:: Trainer

-Initialize network with checkpoint
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-.. autofunction:: LoadCheckpoint
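
This diff collapses the five removed entry points into the single Trainer documented above. As orientation only, here is a minimal sketch of how a Trainer-style run is driven; the constructor arguments (build_training_func, optimizer_args, batch_size), the build-function contract, and train_to_end() are assumptions inferred from the linked distributed_training examples, not something this diff confirms:

.. code-block:: python

  # Hypothetical sketch -- names marked (assumption) are inferred from the
  # linked examples, not guaranteed by this commit.
  import numpy as np
  import tensorflow as tf
  import tensorlayer as tl

  def make_dataset():
      # any tf.data.Dataset of (x, y) pairs; per-worker sharding and
      # shuffling are left to the Trainer (assumption)
      x = np.random.rand(1000, 784).astype(np.float32)
      y = np.random.randint(0, 10, size=(1000,)).astype(np.int64)
      return tf.data.Dataset.from_tensor_slices((x, y))

  def build_network(x, y_):
      # build the model; returning (network, cost) is an assumed contract
      net = tl.layers.InputLayer(x, name='input')
      net = tl.layers.DenseLayer(net, n_units=10, name='output')
      cost = tl.cost.cross_entropy(net.outputs, y_, name='cost')
      return net, cost

  trainer = tl.distributed.Trainer(
      build_training_func=build_network,
      training_dataset=make_dataset(),
      optimizer=tf.train.AdamOptimizer,
      optimizer_args={'learning_rate': 0.001},
      batch_size=32,
  )
  trainer.train_to_end()  # blocks until the dataset is exhausted (assumption)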

docs/modules/layers.rst

Lines changed: 3 additions & 4 deletions
@@ -804,6 +804,9 @@ BinaryConv2d

 Ternary Dense Layer
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+TernaryDenseLayer
+"""""""""""""""""""""
 .. autoclass:: TernaryDenseLayer

 Ternary Convolutions

@@ -849,10 +852,6 @@ QuanConv2dWithBN
 """""""""""""""""""""
 .. autoclass:: QuanConv2dWithBN

-DoReFa
-^^^^^^^^^^^^^^
-
-see Convolutional and Dense APIs.

 .. -----------------------------------------------------------
 .. Recurrent Layers

tensorlayer/package_info.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 """Deep learning and Reinforcement learning library for Researchers and Engineers"""

 # Use the following formatting: (major, minor, patch, prerelease)
-VERSION = (1, 10, 0, "rc0")
+VERSION = (1, 10, 0, "")
 __shortversion__ = '.'.join(map(str, VERSION[:3]))
 __version__ = '.'.join(map(str, VERSION[:3])) + "".join(VERSION[3:])
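
With the prerelease field emptied, the two unchanged join lines now yield a plain release string. For reference, what these expressions evaluate to (plain Python, runnable as-is):

.. code-block:: python

  VERSION = (1, 10, 0, "")
  __shortversion__ = '.'.join(map(str, VERSION[:3]))                    # '1.10.0'
  __version__ = '.'.join(map(str, VERSION[:3])) + "".join(VERSION[3:])  # '1.10.0'

  # before this commit, the prerelease tag was appended with no separator:
  # VERSION = (1, 10, 0, "rc0")  ->  __version__ == '1.10.0rc0'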

0 commit comments
