# main_baselines.py
# Entry script that runs the FL training loop: per-round client updates
# followed by a server update and evaluation of the central model.
from datasets import Data
from nodes import Node
from args import args_parser
from utils import *
from server_funct import *
from client_funct import *
import os
if __name__ == '__main__':
    # Script entry: parse arguments, prepare data and nodes, then run
    # args.T rounds of client update -> server update -> evaluation.
    args = args_parser()

    # Seed the random number generators before any data or model is built.
    setup_seed(args.random_seed)

    # Restrict the process to the requested GPU and make it the current device.
    # NOTE(review): once CUDA_VISIBLE_DEVICES is narrowed to args.device, the
    # remaining device may be re-indexed as 0, in which case
    # 'cuda:' + args.device would be an invalid ordinal for any value other
    # than '0' - confirm on a multi-GPU host.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    torch.cuda.set_device('cuda:' + args.device)

    # Load the datasets / loaders described by args.
    data = Data(args)

    # Aggregation weight of each client, proportional to the length of its
    # train loader.
    # NOTE(review): if train_loader[i] is a DataLoader, len() is a batch
    # count rather than a sample count - verify this is the intended weight.
    loader_lengths = [len(data.train_loader[k]) for k in range(args.node_num)]
    total_length = sum(loader_lengths)
    size_weights = [length / total_length for length in loader_lengths]

    # The central node uses id -1 and is given the first test loader.
    central_node = Node(-1, data.test_loader[0], data.test_set, args)

    # One client node per index, each bound to its own train loader.
    client_nodes = {
        k: Node(k, data.train_loader[k], data.train_set, args)
        for k in range(args.node_num)
    }

    # Bookkeeping: per-round accuracies plus a recorder fed only by the
    # final rounds.
    final_test_acc_recorder = RunningAverage()
    test_acc_recorder = []

    for rounds in range(args.T):
        print('===============Stage 1 The {:d}-th round==============='.format(rounds + 1))
        lr_scheduler(rounds, client_nodes, args)

        # Local training on the clients, then their validation accuracy.
        client_nodes, train_loss = Client_update(args, client_nodes, central_node)
        avg_client_acc = Client_validate(args, client_nodes)
        print(args.server_method + args.client_method + ', averaged clients personalization acc is ', avg_client_acc)

        # Every client participates when select_ratio is exactly 1.0;
        # otherwise a subset is chosen by generate_selectlist.
        select_list = (
            list(range(len(client_nodes)))
            if args.select_ratio == 1.0
            else generate_selectlist(client_nodes, args.select_ratio)
        )

        # Server-side update from the selected clients, then evaluation.
        central_node = Server_update(args, central_node, client_nodes, select_list, size_weights)
        acc = validate(args, central_node, which_dataset = 'local')
        print(args.server_method + args.client_method + ', global model test acc is ', acc)
        test_acc_recorder.append(acc)

        # Only rounds with index >= args.T - 10 (the last ten) feed the
        # final recorder.
        in_final_window = rounds >= args.T - 10
        if in_final_window:
            final_test_acc_recorder.update(acc)

    # NOTE(review): reported once after training finishes; the source layout
    # had lost its indentation, so confirm this print belongs outside the loop.
    print(args.server_method + args.client_method + ', final_testacc is ', final_test_acc_recorder.value())