-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
55 lines (42 loc) · 1.72 KB
/
main.py
File metadata and controls
55 lines (42 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import sys
import rlg
from egg import core
import torch
def main(args):
    """Pretrain the vision module, then train and plot the language game.

    The function runs three stages in order: supervised pretraining of the
    vision classifier with an L1 loss, training of the sender/receiver
    language game, and plotting of an example communication.

    Args:
        args: Argument list handed over by the caller (the script passes
            ``sys.argv``). It is not read in this function;
            ``core.util.init()`` is called without arguments.
            NOTE(review): presumably EGG parses the command line itself --
            confirm against the EGG documentation.
    """
    pretrain_epochs = 15
    game_epochs = 20

    # initialize egg
    core.util.init()
    device = core.get_opts().device

    # initialize vision module
    # The model is moved to the target device *before* the optimizer is
    # built, so the optimizer is constructed over the parameters in their
    # final location (the order recommended by the torch.optim docs).
    class_prediction = rlg.PretrainVision(rlg.Vision()).to(device)
    optimizer = core.build_optimizer(class_prediction.parameters())

    # 80% of data will be used for training. Alternatively, give a tuple,
    # e.g. (100, 20) for 100 training and 20 test images.
    train_data_loader, test_data_loader = rlg.load_dataset(0.8)

    for epoch in range(pretrain_epochs):
        mean_loss, n_batches = 0, 0
        for data, target in train_data_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = class_prediction(data)
            # l1_loss uses reduction='mean' by default, so the result is
            # already a scalar and needs no further .mean().
            loss = torch.nn.functional.l1_loss(output, target)
            loss.backward()
            optimizer.step()
            mean_loss += loss.item()
            n_batches += 1
        print(f'Train Epoch: {epoch}, mean loss: {mean_loss / n_batches}')

    class_prediction.save("newly_trained")

    # initialize game
    # pass only the vision module, the rest was for pretraining
    sender = rlg.Sender(class_prediction.vision_module)
    receiver = rlg.Receiver()

    # to train the game with reinforce, use LanguageGame; to train with
    # gumbel softmax, use LanguageGameGS instead
    game = rlg.LanguageGame(sender, receiver)

    # train the game for `game_epochs` epochs
    game.train2(game_epochs, train_data_loader, test_data_loader)

    # plot an example communication
    # Only the second loader of a fresh 0.2 split is kept for the showcase.
    _, showcase = rlg.load_dataset(0.2)
    game.plot(showcase)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    # Hand the raw command line (program name included) to main().
    cli_arguments = sys.argv
    main(cli_arguments)