From 5149788cb2e0730d1537b9711dcfc5c4b11a0f4b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 9 Aug 2022 10:53:31 +0800 Subject: [PATCH] Fix computing averaged loss in the aishell recipe. (#523) * Fix computing averaged loss in the aishell recipe. * Set find_unused_parameters optionally. --- .../ASR/pruned_transducer_stateless3/train.py | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index 0e5291b214..feaef5cf62 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -22,8 +22,12 @@ Usage: ./prepare.sh + +# If you use a non-zero value for --datatang-prob, you also need to run ./prepare_aidatatang_200zh.sh +If you use --datatang-prob=0, then you don't need to run the above script. + export CUDA_VISIBLE_DEVICES="0,1,2,3" @@ -62,7 +66,6 @@ import torch import torch.multiprocessing as mp import torch.nn as nn - from aidatatang_200zh import AIDatatang200zh from aishell import AIShell from asr_datamodule import AsrDataModule @@ -344,7 +347,7 @@ def get_parser(): parser.add_argument( "--datatang-prob", type=float, - default=0.2, + default=0.0, help="""The probability to select a batch from the aidatatang_200zh dataset. If it is set to 0, you don't need to download the data @@ -945,7 +948,10 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) - loss_value = tot_loss["loss"] / tot_loss["frames"] + if datatang_train_dl is not None: + loss_value = tot_loss["loss"] / tot_loss["frames"] + else: + loss_value = aishell_tot_loss["loss"] / aishell_tot_loss["frames"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: params.best_train_epoch = params.cur_epoch @@ -1032,7 +1038,16 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) + if params.datatang_prob > 0: + find_unused_parameters = True + else: + find_unused_parameters = False + + model = DDP( + model, + device_ids=[rank], + find_unused_parameters=find_unused_parameters, + ) optimizer = Eve(model.parameters(), lr=params.initial_lr)