From 8f6fa4d63bd99b280214c6087379b520328ed18b Mon Sep 17 00:00:00 2001
From: HU XU <110996408+XuHu0529@users.noreply.github.com>
Date: Wed, 1 Mar 2023 15:43:42 +0800
Subject: [PATCH 1/2] Update 05_ddp.md
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

In the sbp example code, wrap the dataset with DistributedSampler so that the dataloader performs the distributed data partitioning
---
 cn/docs/parallelism/05_ddp.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/cn/docs/parallelism/05_ddp.md b/cn/docs/parallelism/05_ddp.md
index 0244de76..98c19ab3 100644
--- a/cn/docs/parallelism/05_ddp.md
+++ b/cn/docs/parallelism/05_ddp.md
@@ -35,8 +35,10 @@
     download=True,
 )

+sampler = flow.utils.data.distributed.DistributedSampler(training_data)
+
 train_dataloader = flow.utils.data.DataLoader(
-    training_data, BATCH_SIZE, shuffle=True
+    training_data, BATCH_SIZE, shuffle=(sampler is None), sampler=sampler
 )

 model = flowvision.models.mobilenet_v2().to(DEVICE)
@@ -48,6 +50,7 @@

 for t in range(EPOCH_NUM):
     print(f"Epoch {t+1}\n-------------------------------")
+    train_dataloader.sampler.set_epoch(t)
     size = len(train_dataloader.dataset)
     for batch, (x, y) in enumerate(train_dataloader):
         x = x.to_global(placement=PLACEMENT, sbp=S0)

From 41b46b4a1e08977690196b69d5de95f77f2417b2 Mon Sep 17 00:00:00 2001
From: brandonliu2
Date: Fri, 14 Apr 2023 04:27:10 +0000
Subject: [PATCH 2/2] Add a note that batch_size is the single-machine,
 single-GPU batch size divided by 2
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 cn/docs/parallelism/05_ddp.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/cn/docs/parallelism/05_ddp.md b/cn/docs/parallelism/05_ddp.md
index 98c19ab3..2c992cdd 100644
--- a/cn/docs/parallelism/05_ddp.md
+++ b/cn/docs/parallelism/05_ddp.md
@@ -91,6 +91,8 @@
         y = y.to_global(placement=PLACEMENT, sbp=S0)
 ```

+- Note that in distributed parallel training, the `BATCH_SIZE` set in the code is the local value on each machine, not the `GLOBAL_BATCH_SIZE`. Running the code above on a single machine with two GPUs and `BATCH_SIZE=64` therefore trains the same way as a single machine with one GPU and `BATCH_SIZE=128`.
+
 In this way, following the introduction in [Common Distributed Parallel Strategies](./01_introduction.md), we carry out distributed data-parallel training by splitting the data with `split(0)` and broadcasting the model.

 ## Using DistributedDataParallel for Data Parallel Training
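
For readers of this series, the data-loading pattern the two patches converge on looks roughly like the sketch below. It is a reconstruction pieced together from the hunks above rather than the verbatim `05_ddp.md` example: the imports, hyperparameter values, and the elided loss/optimizer steps are assumptions, while the sampler, DataLoader, and `set_epoch` calls are taken directly from the diff.

```python
# Minimal sketch of the pattern established by the patches above. The
# imports, hyperparameters, and elided training step are illustrative
# assumptions; sampler/DataLoader/set_epoch calls come from the diff itself.
import oneflow as flow
import flowvision
import flowvision.transforms as transforms

BATCH_SIZE = 64  # local (per-rank) batch size, not the global batch size
EPOCH_NUM = 1    # assumed value, kept small for illustration

PLACEMENT = flow.placement("cuda", [0, 1])  # one machine, two GPUs
S0 = flow.sbp.split(0)

training_data = flowvision.datasets.CIFAR10(
    root="data",  # assumed dataset location
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# DistributedSampler gives each rank a disjoint shard of the dataset;
# shuffling must then be delegated to the sampler, which is why the
# DataLoader passes shuffle=(sampler is None).
sampler = flow.utils.data.distributed.DistributedSampler(training_data)
train_dataloader = flow.utils.data.DataLoader(
    training_data, BATCH_SIZE, shuffle=(sampler is None), sampler=sampler
)

for t in range(EPOCH_NUM):
    # Reseed the sampler every epoch so the shuffled shards differ between
    # epochs; without this call, each epoch replays the same sample order.
    train_dataloader.sampler.set_epoch(t)
    for batch, (x, y) in enumerate(train_dataloader):
        # split(0) stitches the per-rank batches into one global batch, so
        # the effective global batch size is BATCH_SIZE times the number of
        # ranks (64 * 2 = 128 here), matching the note added in PATCH 2/2.
        x = x.to_global(placement=PLACEMENT, sbp=S0)
        y = y.to_global(placement=PLACEMENT, sbp=S0)
        # forward/backward/optimizer steps elided; see 05_ddp.md for the rest
```

The `set_epoch(t)` call is the easy-to-miss piece: the sampler derives its shuffle permutation from the epoch number, so omitting it leaves every rank iterating the same shard order in every epoch.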