multi_slave.py
import torch
import torch.distributed.deprecated as dist
from datasource import Mnist, Mnist_noniid, Cifar10, Cifar10_noniid
from model import CNNMnist, CNNCifar, ResNet18
import copy
from torch.multiprocessing import Process
import argparse
import sys
import os

sys.stdout.flush()

LR = 0.001                   # learning rate for the Adam optimizer
MAX_ROUND = 200              # total number of communication rounds
ROUND_NUMBER_FOR_SAVE = 10   # checkpoint the model every N rounds
ROUND_NUMBER_FOR_REDUCE = 5  # average every N rounds (only used by the commented-out variant below)
IID = True                   # True: every worker trains on the full dataset; False: non-IID shards
DATA_SET = 'Mnist'
#DATA_SET = 'Cifar10'
MODEL = 'CNN'
#MODEL = 'ResNet18'

# get train data from file datasource.py
def get_local_data(size, rank, batchsize):
    if IID:
        if DATA_SET == 'Mnist':
            train_loader = Mnist(batchsize).get_train_data()
        elif DATA_SET == 'Cifar10':
            train_loader = Cifar10(batchsize).get_train_data()
    else:
        if DATA_SET == 'Mnist':
            train_loader = Mnist_noniid(batchsize, size).get_train_data(rank)
        elif DATA_SET == 'Cifar10':
            train_loader = Cifar10_noniid(batchsize, size).get_train_data(rank)
    return train_loader

# get test data from file datasource.py
def get_testset():
    if IID:
        if DATA_SET == 'Mnist':
            test_loader = Mnist().get_test_data()
        elif DATA_SET == 'Cifar10':
            test_loader = Cifar10().get_test_data()
    else:
        if DATA_SET == 'Mnist':
            test_loader = Mnist_noniid().get_test_data()
        elif DATA_SET == 'Cifar10':
            test_loader = Cifar10_noniid().get_test_data()
    # iterate the loader to the end, so the last batch is kept as the test set
    for step, (b_x, b_y) in enumerate(test_loader):
        test_x = b_x
        test_y = b_y
    return test_x, test_y

# broadcast the initial parameters from rank `src` to every rank in `group`,
# so all workers start training from the same model
def init_param(model, src, group):
    for param in model.parameters():
        dist.broadcast(param.data, src=src, group=group)
    sys.stdout.flush()

# checkpoint the model so training can resume after an interruption
def save_model(model, round, rank):
    print('===> Saving models...')
    state = {
        'state': model.state_dict(),
        'round': round,
    }
    torch.save(state, 'autoencoder' + str(rank) + '.t7')

# resume from a checkpoint if one exists; otherwise start at round 0 and
# pull the initial parameters from rank 0
def load_model(model, group, rank):
    print('===> Try resume from checkpoint')
    if os.path.exists('autoencoder' + str(rank) + '.t7'):
        checkpoint = torch.load('autoencoder' + str(rank) + '.t7')
        model.load_state_dict(checkpoint['state'])
        round = checkpoint['round']
        print('===> Load last checkpoint data')
    else:
        round = 0
        init_param(model, 0, group)
    return model, round

# average the model parameters across all clients with a synchronous all-reduce
def all_reduce(model, size, group):
    for param in model.parameters():
        dist.all_reduce(param.data, op=dist.reduce_op.SUM, group=group)
        param.data /= size
    return model
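
# For example, with size=2, a parameter holding 1.0 on rank 0 and 3.0 on
# rank 1 holds 2.0 on both ranks after all_reduce: the SUM op leaves 4.0
# everywhere, and the division by `size` turns it into the mean. This is
# the parameter-averaging step of a FedAvg-style scheme, applied once per
# communication round.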

# new method of exchange: pass parameters around a ring (send to the next
# rank, receive from the previous one), but it is hard to converge
def exchange(model, size, rank):
    old_model = copy.deepcopy(model)
    requests = []
    for param in old_model.parameters():
        requests.append(dist.isend(param.data, dst=(rank + 1) % size))
    for param in model.parameters():
        dist.recv(param.data, src=(rank - 1) % size)
    for req in requests:
        req.wait()  # make sure the non-blocking sends have completed
    return model

def run(size, rank, epoch, batchsize):
    if MODEL == 'CNN' and DATA_SET == 'Mnist':
        model = CNNMnist()
    elif MODEL == 'CNN' and DATA_SET == 'Cifar10':
        model = CNNCifar()
    elif MODEL == 'ResNet18' and DATA_SET == 'Cifar10':
        model = ResNet18()
    else:
        raise ValueError('unsupported MODEL/DATA_SET combination: %s/%s' % (MODEL, DATA_SET))
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
    loss_func = torch.nn.CrossEntropyLoss()
    train_loader = get_local_data(size, rank, batchsize)
    if rank == 0:
        test_x, test_y = get_testset()
    #fo = open("file_multi" + str(rank) + ".txt", 'w')
    group_list = list(range(size))
    group = dist.new_group(group_list)
    model, round = load_model(model, group, rank)
    while round < MAX_ROUND:
        sys.stdout.flush()
        # rank 0 evaluates the shared model at the start of each round
        if rank == 0:
            test_output = model(test_x)
            pred_y = torch.max(test_output, 1)[1].data.numpy()
            accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
            print('Round: ', round, ' Rank: ', rank, '| test accuracy: %.2f' % accuracy)
            #fo.write(str(round) + " " + str(rank) + " " + str(accuracy) + "\n")
        # local training: `epoch` passes over this worker's data per round
        for epoch_cnt in range(epoch):
            for step, (b_x, b_y) in enumerate(train_loader):
                optimizer.zero_grad()
                output = model(b_x)
                loss = loss_func(output, b_y)
                loss.backward()
                optimizer.step()
        # synchronize: average the parameters across all workers every round
        # model = exchange(model, size, rank)
        model = all_reduce(model, size, group)
        # if (round + 1) % ROUND_NUMBER_FOR_REDUCE == 0:
        #     model = all_reduce(model, size, group)
        if (round + 1) % ROUND_NUMBER_FOR_SAVE == 0:
            save_model(model, round + 1, rank)
        round += 1
    #fo.close()

# join the process group, then run the training loop for this rank
def init_processes(size, rank, epoch, batchsize, run):
    dist.init_process_group(backend='tcp', init_method='tcp://127.0.0.1:22222', world_size=size, rank=rank)
    run(size, rank, epoch, batchsize)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--size', '-s', type=int, default=5)
    parser.add_argument('--epoch', '-e', type=int, default=1)
    parser.add_argument('--batchsize', '-b', type=int, default=100)
    args = parser.parse_args()
    size = args.size
    epoch = args.epoch
    batchsize = args.batchsize
    # spawn one worker process per rank on this machine
    processes = []
    for rank in range(size):
        p = Process(target=init_processes, args=(size, rank, epoch, batchsize, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
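
# Example launch (a sketch): this assumes an older PyTorch build that still
# ships torch.distributed.deprecated, with the repo's datasource.py and
# model.py alongside this file. All ranks are spawned on one machine:
#
#   python multi_slave.py --size 5 --epoch 1 --batchsize 100
#
# The TCP rendezvous address 127.0.0.1:22222 is hard-coded in
# init_processes, so every process must run on the same host unless
# init_method is changed.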