diff --git a/setup.py b/setup.py
index 3eb28a7..77ef885 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.6',
+  version = '0.1.7',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
diff --git a/st_moe_pytorch/distributed.py b/st_moe_pytorch/distributed.py
index 448bd31..a62dd7c 100644
--- a/st_moe_pytorch/distributed.py
+++ b/st_moe_pytorch/distributed.py
@@ -86,25 +86,14 @@ def __init__(self, *, dim = 0):
     def forward(self, x, sizes = None):
         return AllGatherFunction.apply(x, self.dim, sizes)
 
-class SplitByRankFunction(Function):
-    @staticmethod
-    def forward(ctx, x):
-        rank = dist.get_rank()
-        out = x[rank]
-
-        if isinstance(x, tuple):
-            sizes = tuple(map(lambda t: t.shape[0], x))
-        else:
-            sizes = (x.shape[1],) * x.shape[0]
-
-        sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
-        ctx.sizes = sizes
-        return out, sizes
+def split_by_rank(x):
+    rank = dist.get_rank()
+    out = x[rank]
 
-    @staticmethod
-    def backward(ctx, grads, _):
-        grads = rearrange(grads, '... -> 1 ...')
-        grads = all_gather_variable_dim(grads, sizes = ctx.sizes)
-        return grads
+    if isinstance(x, tuple):
+        sizes = tuple(map(lambda t: t.shape[0], x))
+    else:
+        sizes = (x.shape[1],) * x.shape[0]
 
-split_by_rank = SplitByRankFunction.apply
+    sizes = torch.tensor(sizes, device = out.device, dtype = torch.long)
+    return out, sizes
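
For context on the second hunk: `split_by_rank` is now a plain helper rather than a custom autograd `Function`. It indexes the gathered input by the local rank and returns that slice together with a tensor of per-rank sizes. Below is a minimal, hypothetical usage sketch, not part of this diff; it assumes st-moe-pytorch 0.1.7 is installed and that the script is launched under torch.distributed, e.g. with `torchrun --nproc_per_node=2 sketch.py`.

```python
# Hypothetical usage sketch, not part of this diff. Assumes st-moe-pytorch
# 0.1.7 is installed and the script runs under torch.distributed, e.g.
#   torchrun --nproc_per_node=2 sketch.py
import torch
import torch.distributed as dist

from st_moe_pytorch.distributed import split_by_rank

if __name__ == '__main__':
    dist.init_process_group('gloo')

    # one tensor per rank, with (possibly) different batch sizes
    gathered = (torch.randn(3, 8), torch.randn(5, 8))

    # each rank keeps only its own slice; sizes records every rank's batch size
    local, sizes = split_by_rank(gathered)

    # e.g. on rank 0: torch.Size([3, 8]) [3, 5]
    print(dist.get_rank(), local.shape, sizes.tolist())

    dist.destroy_process_group()
```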