
Commit

multiply gates by mask_flat twice, as in mesh tensorflow code for top-n gating
lucidrains committed Aug 21, 2023
1 parent 166c41c commit f9b8ce3
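
For context on the commit message, here is a minimal standalone sketch of the pattern it describes (toy shapes and values are made up; the variable names mirror the diff below, but this is not the repository's code). In top-n gating, `mask_flat` is a binary mask marking which token-to-expert assignments survive the expert-capacity cutoff, and the mesh tensorflow reference multiplies the gate weights by this mask both on their own and again inside the combine-tensor product. Since the mask is binary, the second multiplication does not change the values; it simply mirrors the reference code.

```python
import torch

# toy sizes: k routed experts per token, batch size, sequence length (made up)
k, batch, seq = 2, 1, 4

# gate weights for each of the k routed experts, shape (k, batch, seq)
gates = torch.rand(k, batch, seq)

# 1. where the assignment fits under expert capacity, 0. where the token is dropped
mask_flat = torch.tensor([[[1., 1., 0., 1.]],
                          [[1., 0., 1., 1.]]])

# first multiplication: zero out gate weights of dropped assignments
gates = gates * mask_flat

# second multiplication, as it appears again inside the combine-tensor product;
# with a binary mask this is a numerical no-op, matching the mesh tensorflow code
gates_again = gates * mask_flat

assert torch.equal(gates, gates_again)
```
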
Showing 3 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -61,7 +61,7 @@ out, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,)
- [x] redo all the transcribed code from google with einops, as it is not very clear
- [x] consult some MoE experts in the open source community; question why hierarchical MoE is needed, in light of results from soft-MoE
- [x] offer top-n gating generalization, as it seems top3 (with smaller threshold) can work even better
- - [x] figure out if there was an error in <a href="https://github.com/lucidrains/mixture-of-experts/blob/master/mixture_of_experts/mixture_of_experts.py#L210">a previous transcription</a> - yea there was an error
+ - [x] figure out if there was an error in <a href="https://github.com/lucidrains/mixture-of-experts/blob/master/mixture_of_experts/mixture_of_experts.py#L210">a previous transcription</a> - no there was not an error

- [ ] allow for different thresholds for second vs third routed expert
- [ ] improvise a `Top2GatingWithCoordinateDescent` for `MoE` without `importance`
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'st-moe-pytorch',
packages = find_packages(exclude=[]),
- version = '0.0.19',
+ version = '0.0.20',
license='MIT',
description = 'ST - Mixture of Experts - Pytorch',
author = 'Phil Wang',
3 changes: 3 additions & 0 deletions st_moe_pytorch/st_moe_pytorch.py
@@ -266,18 +266,21 @@ def forward(self, x):
mask_flat = reduce(mask, '... n e -> ... n', 'sum')

# (k, batch, sequence) - weighted assignment
# following https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py#L1903
gates = gates * mask_flat

# (batch, sequence, experts, expert_capacity)

N = None

gates = gates[..., N, N]
mask_flat = mask_flat[..., N, N]
one_hot_gate_indices = one_hot_gate_indices[..., N]
safe_one_hot_gates = safe_one_hot(positions.long(), expert_capacity)[..., N, :]

combine_tensor = reduce(
    gates
    * mask_flat
    * one_hot_gate_indices
    * safe_one_hot_gates
, 'k ... -> ...', 'sum')
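
For readers following the broadcasting in this hunk, a rough standalone recreation of the shapes involved (dimension sizes are made up, and `safe_one_hot` is approximated with a plain `torch.eye` one-hot, so this is a sketch of the pattern rather than the repository's code). The `[..., N, N]` indexing aligns the per-token factors so that their product broadcasts to a (k, batch, sequence, experts, expert_capacity) tensor, and the einops `reduce(..., 'k ... -> ...', 'sum')` then sums out the k routed experts:

```python
import torch

# made-up toy sizes
k, batch, seq, experts, capacity = 2, 1, 4, 3, 2

gates = torch.rand(k, batch, seq)                                 # (k, b, n)
mask_flat = torch.randint(0, 2, (k, batch, seq)).float()          # (k, b, n), binary
expert_index = torch.randint(0, experts, (k, batch, seq))
one_hot_gate_indices = torch.eye(experts)[expert_index]           # (k, b, n, e)
positions = torch.randint(0, capacity, (k, batch, seq))
safe_one_hot_gates = torch.eye(capacity)[positions]               # (k, b, n, c)

combine_tensor = (
    gates[..., None, None]              # (k, b, n, 1, 1)
    * mask_flat[..., None, None]        # (k, b, n, 1, 1)
    * one_hot_gate_indices[..., None]   # (k, b, n, e, 1)
    * safe_one_hot_gates[..., None, :]  # (k, b, n, 1, c)
).sum(dim = 0)                          # sum over k -> (b, n, e, c)

assert combine_tensor.shape == (batch, seq, experts, capacity)
```
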
