Skip to content

Commit

Permalink
Researchers will want to log the unweighted auxiliary losses
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 11, 2023
1 parent 5d5f071 commit 977ee55
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'st-moe-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.0',
version = '0.1.1',
license='MIT',
description = 'ST - Mixture of Experts - Pytorch',
author = 'Phil Wang',
Expand Down
8 changes: 4 additions & 4 deletions st_moe_pytorch/st_moe_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def __init__(self,
self.router_z_loss_coef = router_z_loss_coef

def forward(self, x):
dispatch_tensor, combine_tensor, loss, router_z_loss = self.gate(x)
dispatch_tensor, combine_tensor, balance_loss, router_z_loss = self.gate(x)

# dispatch

Expand All @@ -599,12 +599,12 @@ def forward(self, x):

# losses

balance_loss = loss * self.balance_loss_coef
router_z_loss = router_z_loss * self.router_z_loss_coef
weighted_balance_loss = balance_loss * self.balance_loss_coef
weighted_router_z_loss = router_z_loss * self.router_z_loss_coef

# combine the losses

total_aux_loss = balance_loss + router_z_loss
total_aux_loss = weighted_balance_loss + weighted_router_z_loss

return MixtureOfExpertsReturn(output, total_aux_loss, balance_loss, router_z_loss)

Expand Down

0 comments on commit 977ee55

Please sign in to comment.