From d9f5f0891115ecd31290d92f510d56625eb2d417 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 20 Sep 2023 21:09:26 -0700 Subject: [PATCH] allow for noising of gates --- setup.py | 2 +- st_moe_pytorch/st_moe_pytorch.py | 37 +++++++++++++++++++++++++++----- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 5d92dae..ed8a524 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'st-moe-pytorch', packages = find_packages(exclude=[]), - version = '0.1.1', + version = '0.1.2', license='MIT', description = 'ST - Mixture of Experts - Pytorch', author = 'Phil Wang', diff --git a/st_moe_pytorch/st_moe_pytorch.py b/st_moe_pytorch/st_moe_pytorch.py index 484fce2..b6c5961 100644 --- a/st_moe_pytorch/st_moe_pytorch.py +++ b/st_moe_pytorch/st_moe_pytorch.py @@ -83,6 +83,13 @@ def cumsum_exclusive(t, dim = -3): pre_padding = (0, 0) * num_pad_dims return F.pad(t, (*pre_padding, 1, -1)).cumsum(dim = dim) +def log(t, eps = 1e-20): + return torch.log(t.clamp(min = eps)) + +def gumbel_noise(t): + noise = torch.zeros_like(t).uniform_(0, 1) + return -log(-log(noise)) + # pytorch one hot throws an error if there are out of bound indices. # tensorflow, in contrast, does not throw an error @@ -378,7 +385,12 @@ def __init__( self.straight_through_dispatch_tensor = straight_through_dispatch_tensor self.register_buffer('zero', torch.zeros((1,)), persistent = False) - def forward(self, x): + def forward( + self, + x, + noise_gates = False, + noise_mult = 1. + ): """ einstein notation: @@ -407,6 +419,11 @@ def forward(self, x): # gate logits and gates gate_logits = self.to_gates(x) + + if noise_gates: + noise = gumbel_noise(gate_logits) + gate_logits = gate_logits + noise * noise_mult + raw_gates = gate_logits.softmax(dim = -1) # find top N experts per position @@ -582,8 +599,13 @@ def __init__(self, self.balance_loss_coef = balance_loss_coef self.router_z_loss_coef = router_z_loss_coef - def forward(self, x): - dispatch_tensor, combine_tensor, balance_loss, router_z_loss = self.gate(x) + def forward( + self, + x, + noise_gates = False, + noise_mult = 1. + ): + dispatch_tensor, combine_tensor, balance_loss, router_z_loss = self.gate(x, noise_gates = noise_gates, noise_mult = noise_mult) # dispatch @@ -630,7 +652,12 @@ def __init__( self.ff_before = Expert(dim, prenorm = True) if add_ff_before else None self.ff_after = Expert(dim, prenorm = True) if add_ff_after else None - def forward(self, x): + def forward( + self, + x, + noise_gates = False, + noise_mult = 1. + ): # feedforward before @@ -641,7 +668,7 @@ def forward(self, x): residual = x - moe_out, total_aux_loss, balance_loss, router_z_loss = self.moe(self.moe_prenorm(x)) + moe_out, total_aux_loss, balance_loss, router_z_loss = self.moe(self.moe_prenorm(x), noise_gates = noise_gates, noise_mult = noise_mult) x = moe_out + residual