-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
executable file
·69 lines (58 loc) · 2.08 KB
/
main.py
File metadata and controls
executable file
·69 lines (58 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from data.Femnist_data import get_femnist_dataloaders
from data.Sheks_data import get_sheks_dataloaders
from data.data import get_data, get_train_loaders
from plots.plot import save_log
from solvers.baseline import BaseLine
from solvers.fedavg import FedAvg
from solvers.pw import PW
from utils import initialize_model
# Experiment settings. The whole dict is handed to the data-loader factories
# and to the solver constructors used further down in this script.
config = {
    # Algorithm and model selection.
    'solver': "fed_avg",        # one of: 'fed_avg', 'pw', 'baseline'
    'model': "MLP",             # one of: 'basic', '2NN', 'LSTM', 'MLP'

    # Data selection and partitioning.
    'dataset': "FEMNIST",       # one of: "Sheks", "cifar-10", "FEMNIST"
    'batch_size': 10,
    'client_type': "n-iid",     # one of: "iid", "n-iid"
    'num_clients': 100,
    'blk_size': 80,

    # Training schedule.
    'sample_clients': 0.01,     # C
    'client_iterations': 2,     # epochs E

    # Optimisation hyper-parameters.
    'lr': 0.001,
    'beta': 0.4,
    'alpha': 0.1,
    'lambda': 0.1,
}
# Runs unconditionally when this module is loaded, before any loaders exist.
initialize_model()

# Bind `train_loaders` / `test_loaders` for the dataset named in config.
if config["dataset"] not in ("Sheks", "FEMNIST"):
    # Generic path: get_data() also yields a single train/test loader pair
    # (`train_loader`, `test_loader`); these two names exist only on this
    # path. The single test loader is wrapped in a list so that
    # `test_loaders` is a list here too.
    train_dataset, train_loader, test_loader = get_data(config)
    test_loaders = [test_loader]
    train_loaders = get_train_loaders(config, train_dataset)
elif config["dataset"] == "FEMNIST":
    train_loaders, test_loaders = get_femnist_dataloaders(config)
else:
    # Only "Sheks" can reach this branch.
    train_loaders, test_loaders = get_sheks_dataloaders(config)
if __name__ == "__main__":
    # Script entry point: train with the solver named in config['solver'] and
    # write the resulting accuracy and loss series via save_log().
    solver_name = config['solver']

    if solver_name == "baseline":
        # BaseLine consumes the single `train_loader` / `test_loader` pair,
        # which the dataset-loading code binds only on its generic
        # (non-Sheks, non-FEMNIST) path. Stop with a clear message instead of
        # hitting a NameError on those two names.
        if config["dataset"] in ("Sheks", "FEMNIST"):
            raise SystemExit(
                "baseline solver is not available for dataset "
                f"{config['dataset']!r}: no single train/test loader is built"
            )
        acc_s, loss_s = BaseLine(train_loader, test_loader, config).train()
    else:
        # Federated solvers share one constructor signature and return a
        # third value from train() that this script does not use.
        # Not wired up yet: fed_prox (FedProx), fed_opt (FedOPT), cwt (CWT),
        # scaffold (SCAFFOLD), ditto (DITTO).
        federated_solvers = {
            "fed_avg": FedAvg,
            "pw": PW,
        }
        fed_solver = federated_solvers.get(solver_name)
        if fed_solver is None:
            # Abort here; continuing would try to call None as a solver.
            raise SystemExit("wrong algorithm selected!")
        acc_s, loss_s, _idx = fed_solver(
            train_loaders, test_loaders, config
        ).train()

    save_log(solver_name, acc_s, "acc")
    save_log(solver_name, loss_s, "loss")