Merged
70 changes: 33 additions & 37 deletions python/paddle/distributed/parallel.py
@@ -29,13 +29,13 @@
ParallelStrategy = core.ParallelStrategy


def init_parallel_env(backend='nccl'):
def init_parallel_env():
"""
Initialize parallel training environments in dynamic mode.
Initialize parallel training environment in dynamic graph mode.

Args:
backend(str, optional): The backend to communication between multiple devices.
Now only support ``nccl`` . Default value is ``nccl`` .
.. note::
Now only supports initializing the GPU parallel training
environment and using NCCL for communication.

Returns:
None
@@ -89,14 +89,12 @@ def train():
dist.spawn(train)
"""

# 1. input check
if not isinstance(backend, six.string_types):
raise TypeError("input `backend` type error, expected type is str, "
"but received type is %s." % type(backend))
if cpt.to_text(backend) != 'nccl':
raise ValueError(
"backend `%s` is not supported, now only supports `nccl` backend." %
backend)
# 1. gpu check
if not core.is_compiled_with_cuda():
raise NotImplementedError(
"Cannot initialize parallel environment in CPU-only version, now only "
"supports initializing the GPU parallel environment. Please recompile "
"or reinstall paddle with GPU support.")

# 2. check env
def _check_var_exists(var_name):
@@ -112,30 +110,28 @@ def _check_var_exists(var_name):
_check_var_exists("PADDLE_TRAINERS_NUM")
_check_var_exists("PADDLE_TRAINER_ENDPOINTS")

# 3. init ParallelStrategy
# 3. init NCCL ParallelStrategy
strategy = ParallelStrategy()
if cpt.to_text(backend) == 'nccl':
if parallel_helper._is_parallel_ctx_initialized():
warnings.warn("The parallel environment has been initialized.")
strategy.nranks = ParallelEnv().world_size
strategy.local_rank = ParallelEnv().rank
strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
strategy.current_endpoint = ParallelEnv().current_endpoint
if strategy.nranks < 2:
return
# NOTE(chenweihang): [ why config global place here? ]
# the dygraph mode will be set to default mode,
# users will not call `dygraph.guard` or `enable_dygraph`
# directly, if they want to switch default place,
# they need to call a function to change default place,
# here just set correctly place to users
place = core.CUDAPlace(ParallelEnv().device_id)
_set_expected_place(place)

# init nccl context
parallel_helper._set_parallel_ctx(
core.NCCLParallelContext(strategy, place))
parallel_helper._init_parallel_ctx()
if parallel_helper._is_parallel_ctx_initialized():
warnings.warn("The parallel environment has been initialized.")
strategy.nranks = ParallelEnv().world_size
strategy.local_rank = ParallelEnv().rank
strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
strategy.current_endpoint = ParallelEnv().current_endpoint
if strategy.nranks < 2:
return
# NOTE(chenweihang): [ why config global place here? ]
# the dygraph mode will be set to default mode,
# users will not call `dygraph.guard` or `enable_dygraph`
# directly, if they want to switch default place,
# they need to call a function to change default place,
# here just set correctly place to users
place = core.CUDAPlace(ParallelEnv().device_id)
[Review comment — Member] Can I call this function with CPU whl?
[Reply — Contributor Author] thx, add cuda check for this function
_set_expected_place(place)

# init nccl context
parallel_helper._set_parallel_ctx(core.NCCLParallelContext(strategy, place))
parallel_helper._init_parallel_ctx()


def get_rank():
@@ -163,7 +159,7 @@ def get_rank():

def get_world_size():
"""
The number of trainers (number of processes participating in current job).
Returns the number of trainers (number of processes participating in the current job).

Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` .
The default value is 1.
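For context, here is a minimal sketch of how the simplified entry point is used after this change: there is no ``backend`` argument any more, and the GPU/NCCL requirement is checked inside ``init_parallel_env`` itself. The toy layer, tensor shapes, and process count below are illustrative only and are not part of the diff:

    import paddle
    import paddle.nn as nn
    import paddle.distributed as dist

    def train():
        # dynamic graph mode must be enabled before initializing the parallel env
        paddle.disable_static()
        # no backend argument; raises NotImplementedError on a CPU-only build
        dist.init_parallel_env()

        # wrap an arbitrary layer for data parallel training
        layer = paddle.DataParallel(nn.Linear(10, 1))
        out = layer(paddle.randn([4, 10], 'float32'))
        print(dist.get_rank(), dist.get_world_size(), out.shape)

    if __name__ == '__main__':
        # spawn one process per GPU (here 2); rank info comes from environment variables
        dist.spawn(train, nprocs=2)

This mirrors the docstring example in the diff above, trimmed to the calls touched by this PR.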
10 changes: 4 additions & 6 deletions python/paddle/distributed/spawn.py
@@ -236,8 +236,6 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
func (function): The target function is called by the spawned process.
This function needs to be able to be pickled, so it must be defined
at the top level of a module.
This function should be called as ``func(i, *args)``, ``i`` is
the process index and ``args`` contains other arguments as tuple.
args (tuple, optional): Arguments passed to ``func``.
nprocs (int, optional): Number of processes to start. Default: -1.
when nprocs is -1, the available device will be obtained from
@@ -246,8 +244,8 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
variable CUDA_VISIBLE_DEVICES; If use CPU, the currently available
CPU number is obtained from the environment variable CPU_NUM.
For example, export CPU_NUM=4, if the environment variable is not set,
the executor will add the variable to the environment variable and
set its value to 1.
the spawn method will set the environment variable to a default
value of 1.
join (bool, optional): Perform a blocking join on all spawned processes.
Default: True.
daemon (bool, optional): The spawned processes' daemon flag. Default: False.
@@ -266,8 +264,8 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
such as 6170. Default: None;
(5) selected_gpus (string): The training process will run on the
selected_gpus, such as "0,1,2,3". Default: None;
(6) print_config: Print current parallel training config. Default: False;
(7) use_paddlecloud: Whether to use paddlecloud platform to run your
(6) print_config (bool): Print current parallel training config. Default: False;
(7) use_paddlecloud (bool): Whether to use paddlecloud platform to run your
multi-process job. Default: False.

Returns:
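A small usage sketch consistent with the parameters and options documented above (``args``, ``nprocs``, ``selected_gpus``); the toy ``train`` function and the GPU ids are placeholders, not part of this PR:

    import paddle
    import paddle.distributed as dist

    def train(print_result=False):
        paddle.disable_static()
        dist.init_parallel_env()
        data = paddle.randn([2, 4], 'float32')
        if print_result:
            print("rank {} got shape {}".format(dist.get_rank(), data.shape))

    if __name__ == '__main__':
        # the entries of ``args`` are forwarded to ``train``;
        # ``selected_gpus`` limits which devices the two processes occupy
        dist.spawn(train, args=(True, ), nprocs=2, selected_gpus='0,1')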
158 changes: 94 additions & 64 deletions python/paddle/fluid/dygraph/parallel.py
@@ -349,38 +349,53 @@ def scale_loss(self, loss):
Examples:
.. code-block:: python

import numpy as np
import paddle.fluid as fluid

place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id)
with fluid.dygraph.guard(place):

# prepare the data parallel context
strategy = fluid.dygraph.prepare_context()

linear = fluid.dygraph.Linear(1, 10, act="softmax")
adam = fluid.optimizer.AdamOptimizer(
learning_rate=0.001, parameter_list=linear.parameters())

# make the module become the data parallelism module
linear = fluid.dygraph.DataParallel(linear, strategy)

x_data = np.random.random(size=[10, 1]).astype(np.float32)
data = fluid.dygraph.to_variable(x_data)

hidden = linear(data)
avg_loss = fluid.layers.mean(hidden)

# scale the loss according to the number of trainers.
avg_loss = linear.scale_loss(avg_loss)

avg_loss.backward()

# collect the gradients of trainers.
linear.apply_collective_grads()

adam.minimize(avg_loss)
linear.clear_gradients()
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist

class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)

def forward(self, x):
return self._linear2(self._linear1(x))

def train():
# 1. enable dynamic mode
paddle.disable_static()

# 2. initialize parallel environment
dist.init_parallel_env()

# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)

loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())

# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)

loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()

adam.step()
adam.clear_grad()

if __name__ == '__main__':
# 1. start by ``paddle.distributed.spawn`` (default)
dist.spawn(train, nprocs=2)
# 2. start by ``paddle.distributed.launch``
# train()
"""
if not self._is_data_parallel_mode():
return loss
@@ -438,38 +453,53 @@ def apply_collective_grads(self):
Examples:
.. code-block:: python

import numpy as np
import paddle.fluid as fluid

place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id)
with fluid.dygraph.guard(place):

# prepare the data parallel context
strategy = fluid.dygraph.prepare_context()

linear = fluid.dygraph.Linear(1, 10, act="softmax")
adam = fluid.optimizer.AdamOptimizer(
learning_rate=0.001, parameter_list=linear.parameters())

# make the module become the data parallelism module
linear = fluid.dygraph.DataParallel(linear, strategy)

x_data = np.random.random(size=[10, 1]).astype(np.float32)
data = fluid.dygraph.to_variable(x_data)

hidden = linear(data)
avg_loss = fluid.layers.mean(hidden)

# scale the loss according to the number of trainers.
avg_loss = linear.scale_loss(avg_loss)

avg_loss.backward()

# collect the gradients of trainers.
linear.apply_collective_grads()

adam.minimize(avg_loss)
linear.clear_gradients()
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist

class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)

def forward(self, x):
return self._linear2(self._linear1(x))

def train():
# 1. enable dynamic mode
paddle.disable_static()

# 2. initialize parallel environment
dist.init_parallel_env()

# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)

loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())

# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)

loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()

adam.step()
adam.clear_grad()

if __name__ == '__main__':
# 1. start by ``paddle.distributed.spawn`` (default)
dist.spawn(train, nprocs=2)
# 2. start by ``paddle.distributed.launch``
# train()
"""
if not self._is_data_parallel_mode():
return
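Conceptually, the two methods documented above work together: ``scale_loss`` divides the loss by the trainer count before ``backward``, and ``apply_collective_grads`` all-reduces (sums) the gradients across trainers, so the combined effect behaves like a gradient average. A rough sketch of that idea, assuming ``nranks`` trainers and a Paddle tensor ``loss`` (not the actual implementation):

    def scale_loss_sketch(loss, nranks):
        # single-process / non data-parallel mode: nothing to scale
        if nranks < 2:
            return loss
        # divide here so that summed (all-reduced) gradients become an average
        return loss / nranks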
@@ -30,15 +30,9 @@
# executed in the python3 sub-process.


@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestInitParallelEnv(unittest.TestCase):
def test_beckend_type_error(self):
with self.assertRaises(TypeError):
dist.init_parallel_env(backend=1)

def test_backend_value_error(self):
with self.assertRaises(ValueError):
dist.init_parallel_env(backend="mpi")

def test_check_env_failed(self):
os.environ['FLAGS_selected_gpus'] = '0'
os.environ['PADDLE_TRAINER_ID'] = '0'
4 changes: 2 additions & 2 deletions python/paddle/framework/__init__.py
@@ -20,8 +20,8 @@
]

__all__ += [
'grad', 'LayerList', 'load', 'save', 'prepare_context', 'to_variable',
'no_grad', 'ParallelEnv', 'DataParallel'
'grad', 'LayerList', 'load', 'save', 'to_variable', 'no_grad',
'DataParallel'
]

__all__ += [