add ability to use differentiable topk
lucidrains committed Aug 26, 2023
1 parent 22dfd4d commit 97a5688
Showing 2 changed files with 32 additions and 13 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.21',
+  version = '0.0.22',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
43 changes: 31 additions & 12 deletions st_moe_pytorch/st_moe_pytorch.py
@@ -11,7 +11,7 @@

 from einops import rearrange, repeat, reduce, pack, unpack

-from colt5_attention import topk as differentiable_topk
+from colt5_attention import topk as maybe_differentiable_topk

 # constants

@@ -156,13 +156,23 @@ def __init__(
         threshold_eval: Union[float, Tuple[float, ...]] = 0.2,
         capacity_factor_train = 1.25,
         capacity_factor_eval = 2.,
-        straight_through_dispatch_tensor = True
+        straight_through_dispatch_tensor = True,
+        differentiable_topk = False,
+        differentiable_topk_fused = True
     ):
         super().__init__()
         self.eps = eps
         self.num_gates = num_gates
         self.to_gates = nn.Linear(dim, num_gates, bias = False)

+        self.differentiable_topk = differentiable_topk
+
+        self.topk = partial(
+            maybe_differentiable_topk,
+            non_differentiable = not differentiable_topk,
+            fused = differentiable_topk_fused # use triton fused coordinate descent if possible by default
+        )
+
         assert top_n >= 2, 'must be 2 or more experts'
         self.top_n = top_n
         top_n_minus_1 = top_n - 1
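
For intuition on what the new flags buy: a hard topk passes gradient only to the selected gate logits, while the CoLT5 coordinate-descent relaxation produces soft scores that stay differentiable with respect to all of them. The sketch below only illustrates the general straight-through idea (hard selection forward, soft relaxation backward); it is not the coordinate-descent routine that `maybe_differentiable_topk` wires in above.

import torch
import torch.nn.functional as F

def straight_through_topk(logits, k, temperature = 1.):
    # soft relaxation of the gate distribution
    soft = F.softmax(logits / temperature, dim = -1)

    # hard top-k mask built from the actual top-k indices
    indices = logits.topk(k, dim = -1).indices
    hard = torch.zeros_like(soft).scatter(-1, indices, 1.)

    # forward pass sees the hard mask, backward pass sees the soft relaxation
    return hard + soft - soft.detach(), indices

gate_logits = torch.randn(4, 8, requires_grad = True)            # (batch, num gates)
mask, indices = straight_through_topk(gate_logits, k = 2)

expert_outputs = torch.randn(4, 8, 16)                            # dummy per-expert outputs, (batch, num gates, dim)
mixed = torch.einsum('b g, b g d -> b d', mask, expert_outputs)
mixed.sum().backward()                                            # gradient reaches every gate logit via the soft term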
@@ -214,7 +224,17 @@ def forward(self, x):

         # find top N experts per position

-        gates, gate_indices = raw_gates.topk(k = top_n, dim = -1)
+        topk_return = self.topk(raw_gates, k = top_n)
+
+        gate_indices = topk_return.indices
+
+        if self.differentiable_topk:
+            # allow for differentiable topk using coordinate descent
+            # used successfully for routing from CoLT5 paper https://github.com/lucidrains/CoLT5-attention
+
+            gates = topk_return.coor_descent_values
+        else:
+            gates = topk_return.values

         # move the top-n dimension to be first

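Putting the two hunks together, the gating call now reduces to something like the sketch below. This is only an illustration: the keyword arguments mirror the `partial` built in `__init__`, the return fields (`.indices`, `.values`, `.coor_descent_values`) are the ones read above, and the dummy `raw_gates` tensor stands in for the softmaxed gate logits.

import torch
from colt5_attention import topk as maybe_differentiable_topk

raw_gates = torch.randn(2, 1024, 8).softmax(dim = -1)   # stand-in for (batch, seq, num gates) gate probabilities

topk_return = maybe_differentiable_topk(
    raw_gates,
    k = 2,
    non_differentiable = False,   # differentiable_topk = True in the module above
    fused = True                  # prefer the fused triton coordinate descent kernel when available
)

gate_indices = topk_return.indices          # hard expert assignments per position
gates = topk_return.coor_descent_values     # soft, differentiable gate weights used for combining
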
@@ -342,25 +362,24 @@ def __init__(self,
         loss_coef = 1e-2,
         router_z_loss_coef = 1e-3,
         experts: Optional[Module] = None,
-        straight_through_dispatch_tensor = True
+        straight_through_dispatch_tensor = True,
+        differentiable_topk = False,
+        differentiable_topk_fused = True
     ):
         super().__init__()
         self.dim = dim
         self.num_experts = num_experts

-        gating_kwargs = dict(
-            threshold_train = threshold_train,
-            threshold_eval = threshold_eval,
-            capacity_factor_train = capacity_factor_train,
-            capacity_factor_eval = capacity_factor_eval
-        )
-
         self.gate = TopNGating(
             dim,
             top_n = gating_top_n,
             num_gates = num_experts,
             straight_through_dispatch_tensor = straight_through_dispatch_tensor,
-            **gating_kwargs
+            differentiable_topk = differentiable_topk,
+            threshold_train = threshold_train,
+            threshold_eval = threshold_eval,
+            capacity_factor_train = capacity_factor_train,
+            capacity_factor_eval = capacity_factor_eval
         )

         self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_mult = expert_hidden_mult))

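At the user-facing level the new option is just another constructor flag. A hedged usage sketch, assuming the `__init__` above belongs to the package's MoE module exported at the top level (as in the repo README); the exact forward return (routed output plus auxiliary losses) also follows the README:

import torch
from st_moe_pytorch import MoE

moe = MoE(
    dim = 512,
    num_experts = 16,
    gating_top_n = 2,
    differentiable_topk = True   # new in this commit: route with coordinate-descent topk instead of a hard topk
)

x = torch.randn(1, 1024, 512)
out = moe(x)   # consult the README for the exact return signature (output and auxiliary losses)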