From 977ee5500c000b47f270d6cc4ccd33bef380ca0f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 11 Sep 2023 14:43:44 -0700 Subject: [PATCH] researcher will want to log the unweighted auxiliary losses --- setup.py | 2 +- st_moe_pytorch/st_moe_pytorch.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 4f3a4e1..5d92dae 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'st-moe-pytorch', packages = find_packages(exclude=[]), - version = '0.1.0', + version = '0.1.1', license='MIT', description = 'ST - Mixture of Experts - Pytorch', author = 'Phil Wang', diff --git a/st_moe_pytorch/st_moe_pytorch.py b/st_moe_pytorch/st_moe_pytorch.py index cd74147..484fce2 100644 --- a/st_moe_pytorch/st_moe_pytorch.py +++ b/st_moe_pytorch/st_moe_pytorch.py @@ -583,7 +583,7 @@ def __init__(self, self.router_z_loss_coef = router_z_loss_coef def forward(self, x): - dispatch_tensor, combine_tensor, loss, router_z_loss = self.gate(x) + dispatch_tensor, combine_tensor, balance_loss, router_z_loss = self.gate(x) # dispatch @@ -599,12 +599,12 @@ def forward(self, x): # losses - balance_loss = loss * self.balance_loss_coef - router_z_loss = router_z_loss * self.router_z_loss_coef + weighted_balance_loss = balance_loss * self.balance_loss_coef + weighted_router_z_loss = router_z_loss * self.router_z_loss_coef # combine the losses - total_aux_loss = balance_loss + router_z_loss + total_aux_loss = weighted_balance_loss + weighted_router_z_loss return MixtureOfExpertsReturn(output, total_aux_loss, balance_loss, router_z_loss)