From 97a56888e373363c6c5cd1211701939f76967ab3 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 25 Aug 2023 17:14:07 -0700
Subject: [PATCH] add ability to use differentiable topk

---
 setup.py                         |  2 +-
 st_moe_pytorch/st_moe_pytorch.py | 43 +++++++++++++++++++++++---------
 2 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/setup.py b/setup.py
index a3e1055..602dea8 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.21',
+  version = '0.0.22',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
diff --git a/st_moe_pytorch/st_moe_pytorch.py b/st_moe_pytorch/st_moe_pytorch.py
index 0bb5319..01ddd85 100644
--- a/st_moe_pytorch/st_moe_pytorch.py
+++ b/st_moe_pytorch/st_moe_pytorch.py
@@ -11,7 +11,7 @@
 
 from einops import rearrange, repeat, reduce, pack, unpack
 
-from colt5_attention import topk as differentiable_topk
+from colt5_attention import topk as maybe_differentiable_topk
 
 # constants
 
@@ -156,13 +156,23 @@ def __init__(
         threshold_eval: Union[float, Tuple[float, ...]] = 0.2,
         capacity_factor_train = 1.25,
         capacity_factor_eval = 2.,
-        straight_through_dispatch_tensor = True
+        straight_through_dispatch_tensor = True,
+        differentiable_topk = False,
+        differentiable_topk_fused = True
     ):
         super().__init__()
         self.eps = eps
         self.num_gates = num_gates
         self.to_gates = nn.Linear(dim, num_gates, bias = False)
 
+        self.differentiable_topk = differentiable_topk
+
+        self.topk = partial(
+            maybe_differentiable_topk,
+            non_differentiable = not differentiable_topk,
+            fused = differentiable_topk_fused # use triton fused coordinate descent if possible by default
+        )
+
         assert top_n >= 2, 'must be 2 or more experts'
         self.top_n = top_n
         top_n_minus_1 = top_n - 1
@@ -214,7 +224,17 @@ def forward(self, x):
 
         # find top N experts per position
 
-        gates, gate_indices = raw_gates.topk(k = top_n, dim = -1)
+        topk_return = self.topk(raw_gates, k = top_n)
+
+        gate_indices = topk_return.indices
+
+        if self.differentiable_topk:
+            # allow for differentiable topk using coordinate descent
+            # used successfully for routing from CoLT5 paper https://github.com/lucidrains/CoLT5-attention
+
+            gates = topk_return.coor_descent_values
+        else:
+            gates = topk_return.values
 
         # move the top-n dimension to be first
 
@@ -342,25 +362,24 @@ def __init__(self,
         loss_coef = 1e-2,
         router_z_loss_coef = 1e-3,
         experts: Optional[Module] = None,
-        straight_through_dispatch_tensor = True
+        straight_through_dispatch_tensor = True,
+        differentiable_topk = False,
+        differentiable_topk_fused = True
     ):
         super().__init__()
         self.dim = dim
         self.num_experts = num_experts
 
-        gating_kwargs = dict(
-            threshold_train = threshold_train,
-            threshold_eval = threshold_eval,
-            capacity_factor_train = capacity_factor_train,
-            capacity_factor_eval = capacity_factor_eval
-        )
-
         self.gate = TopNGating(
             dim,
             top_n = gating_top_n,
             num_gates = num_experts,
             straight_through_dispatch_tensor = straight_through_dispatch_tensor,
-            **gating_kwargs
+            differentiable_topk = differentiable_topk,
+            threshold_train = threshold_train,
+            threshold_eval = threshold_eval,
+            capacity_factor_train = capacity_factor_train,
+            capacity_factor_eval = capacity_factor_eval
         )
 
         self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_mult = expert_hidden_mult))
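
Usage note (not part of the patch): a minimal sketch of the two routing paths this change wires up, calling colt5_attention's topk the same way the partial built in TopNGating.__init__ does. The tensor shapes and the fused = False setting are illustrative assumptions, not taken from the diff.

import torch
from colt5_attention import topk as maybe_differentiable_topk

# hypothetical router probabilities: batch of 2, 1024 tokens, 16 experts
raw_gates = torch.randn(2, 1024, 16, requires_grad = True).softmax(dim = -1)

# differentiable path (differentiable_topk = True): coordinate descent relaxation,
# gradients flow back into raw_gates through coor_descent_values
soft = maybe_differentiable_topk(raw_gates, k = 2, non_differentiable = False, fused = False)
gate_indices = soft.indices
gates = soft.coor_descent_values

# default path (differentiable_topk = False): plays the role of the old raw_gates.topk(k = 2, dim = -1)
hard = maybe_differentiable_topk(raw_gates, k = 2, non_differentiable = True)
gates_hard = hard.values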