Fix computing averaged loss in the aishell recipe. (k2-fsa#523)
* Fix computing averaged loss in the aishell recipe.

* Set find_unused_parameters optionally.
csukuangfj authored Aug 9, 2022
1 parent f24b76e commit 5149788
Showing 1 changed file with 19 additions and 4 deletions.
egs/aishell/ASR/pruned_transducer_stateless3/train.py
@@ -22,8 +22,12 @@
 Usage:

 ./prepare.sh
+
+# If you use a non-zero value for --datatang-prob, you also need to run
 ./prepare_aidatatang_200zh.sh
+
+If you use --datatang-prob=0, then you don't need to run the above script.

 export CUDA_VISIBLE_DEVICES="0,1,2,3"
@@ -62,7 +66,6 @@
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-
 from aidatatang_200zh import AIDatatang200zh
 from aishell import AIShell
 from asr_datamodule import AsrDataModule
@@ -344,7 +347,7 @@ def get_parser():
     parser.add_argument(
         "--datatang-prob",
         type=float,
-        default=0.2,
+        default=0.0,
         help="""The probability to select a batch from the
         aidatatang_200zh dataset.
         If it is set to 0, you don't need to download the data
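For context, --datatang-prob drives a per-step random choice between the two training corpora. A minimal sketch of that selection logic (the helper name and iterator arguments are illustrative, not the recipe's exact API):

import random
from typing import Iterator, Optional, Tuple

def pick_batch(
    aishell_iter: Iterator,
    datatang_iter: Optional[Iterator],
    datatang_prob: float,
) -> Tuple[dict, str]:
    # With probability datatang_prob, draw the next batch from
    # aidatatang_200zh; otherwise draw from aishell.  With the new
    # default of 0.0, the second corpus is never touched.
    if datatang_iter is not None and random.random() < datatang_prob:
        return next(datatang_iter), "datatang"
    return next(aishell_iter), "aishell"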
@@ -945,7 +948,10 @@ def train_one_epoch(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )

-        loss_value = tot_loss["loss"] / tot_loss["frames"]
+        if datatang_train_dl is not None:
+            loss_value = tot_loss["loss"] / tot_loss["frames"]
+        else:
+            loss_value = aishell_tot_loss["loss"] / aishell_tot_loss["frames"]
         params.train_loss = loss_value
         if params.train_loss < params.best_train_loss:
             params.best_train_epoch = params.cur_epoch
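The fix routes the epoch-level averaged loss through whichever tracker reflects the data actually seen: the combined tot_loss when both corpora are in play, and aishell_tot_loss when only aishell is used (presumably the combined tracker is empty or stale in that mode). A minimal sketch of the per-frame averaging, with plain dicts and hypothetical numbers standing in for the recipe's trackers:

def averaged_loss(tracker: dict) -> float:
    # Per-frame average: summed loss divided by summed frame count, so each
    # batch contributes in proportion to how many frames it carries.
    return tracker["loss"] / tracker["frames"]

aishell_tot_loss = {"loss": 1520.0, "frames": 48000}  # hypothetical values
tot_loss = {"loss": 0.0, "frames": 0}  # combined tracker; unused without datatang

datatang_train_dl = None  # aishell-only training
tracker = tot_loss if datatang_train_dl is not None else aishell_tot_loss
print(averaged_loss(tracker))  # avoids dividing by the empty combined tracker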
@@ -1032,7 +1038,16 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+        if params.datatang_prob > 0:
+            find_unused_parameters = True
+        else:
+            find_unused_parameters = False
+
+        model = DDP(
+            model,
+            device_ids=[rank],
+            find_unused_parameters=find_unused_parameters,
+        )

     optimizer = Eve(model.parameters(), lr=params.initial_lr)