Commit fa199ad

Fix OptimizerWrapper creation, test gradient clipping (#593)
1 parent: fdda330

2 files changed (+7, -3 lines)

hivemind/moe/server/layers/optim.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -1,11 +1,10 @@
 import torch
 
 
-class OptimizerWrapper(torch.optim.Optimizer):
+class OptimizerWrapper:
     """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer"""
 
     def __init__(self, optim: torch.optim.Optimizer):
-        super().__init__(optim.param_groups, optim.defaults)
         self.optim = optim
 
     @property
```
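Before this change, `OptimizerWrapper` both subclassed `torch.optim.Optimizer` and held a separate inner optimizer, so `super().__init__(optim.param_groups, optim.defaults)` built a second copy of the parameter groups and defaults on the wrapper itself. After the fix the class is plain composition: it keeps only `self.optim` and forwards to it. Below is a minimal, self-contained sketch of that delegation pattern; the `ForwardingOptimizerWrapper` name and the `__getattr__`-based forwarding are illustrative assumptions, since the diff truncates the real class right after its first `@property`.

```python
import torch


class ForwardingOptimizerWrapper:
    """Wraps a torch.optim.Optimizer and forwards attribute access to it.

    Hypothetical stand-alone example; the real hivemind class defines
    additional properties and methods that this diff does not show.
    """

    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim

    def __getattr__(self, name):
        # Invoked only when normal attribute lookup fails, so `self.optim`
        # (stored in __dict__) never recurses here; step(), zero_grad(),
        # param_groups, state_dict(), etc. are all forwarded.
        return getattr(self.optim, name)


# Usage: the wrapper behaves like the wrapped SGD instance.
param = torch.nn.Parameter(torch.zeros(3))
wrapper = ForwardingOptimizerWrapper(torch.optim.SGD([param], lr=0.05))
param.grad = torch.ones(3)
wrapper.step()       # forwarded to SGD.step()
wrapper.zero_grad()  # forwarded to SGD.zero_grad()
assert param.data.allclose(torch.full((3,), -0.05))
```

With composition there is exactly one copy of the optimizer state, so calls like `step()` and reads of `param_groups` all hit the wrapped optimizer directly.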

tests/test_training.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -20,7 +20,12 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
     SGD = partial(torch.optim.SGD, lr=0.05)
 
     with background_server(
-        num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
+        num_experts=2,
+        device="cpu",
+        optim_cls=SGD,
+        hidden_dim=64,
+        num_handlers=1,
+        clip_grad_norm=1.0,
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         expert1, expert2 = create_remote_experts(
```
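The call is reformatted to one keyword per line and gains `clip_grad_norm=1.0`, so the test now also exercises the server's gradient-clipping path. As a rough sketch of what such a setting conventionally corresponds to in PyTorch (the hivemind server internals are not part of this diff, so the wiring below is an assumption):

```python
import torch

model = torch.nn.Linear(64, 64)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
clip_grad_norm = 1.0

# One training step with clipping applied between backward() and step().
loss = model(torch.randn(8, 64)).pow(2).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
optimizer.step()
optimizer.zero_grad()
```

`torch.nn.utils.clip_grad_norm_` rescales all gradients in place so their combined L2 norm does not exceed the threshold, which keeps a single noisy batch from producing an oversized SGD step.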
