forked from sairadin/FLMcancer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserver_new.py
103 lines (83 loc) · 3.28 KB
/
server_new.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from typing import Dict, Optional, Tuple, List, Union
from pathlib import Path
from grpc import server
import flwr as fl
import tensorflow as tf
from flwr.server.client_proxy import ClientProxy
from flwr.server.client_manager import ClientManager
from flwr.common import (
EvaluateRes,
FitRes,
Parameters,
Scalar,
Weights,
)
# from typing import Callable
import numpy as np
from datetime import datetime
import os
from tensorflow.keras import layers as L
# from tensorflow.keras.applications.efficientnet import EfficientNetB2 as efn
import efficientnet.tfkeras as efn
def load_model():
    """Build and compile the EfficientNetB2-based single-output classifier.

    Architecture: an EfficientNetB2 backbone (ImageNet weights, no top) on
    384x384x3 inputs, global average pooling, four Dense(ReLU) + Dropout
    pairs, and a final 1-unit sigmoid Dense layer.

    Returns:
        A ``tf.keras.Sequential`` compiled with the Adam optimizer and
        binary cross-entropy loss, tracking binary cross-entropy and accuracy.
    """
    image_hw = [384, 384]
    # (units, dropout rate) for each hidden Dense/Dropout pair, in order.
    hidden_head = ((1024, 0.3), (512, 0.2), (256, 0.2), (128, 0.1))

    stack = [
        efn.EfficientNetB2(
            input_shape=(*image_hw, 3),
            weights='imagenet',
            include_top=False,
        ),
        L.GlobalAveragePooling2D(),
    ]
    for units, rate in hidden_head:
        stack.append(L.Dense(units, activation='relu'))
        stack.append(L.Dropout(rate))
    stack.append(L.Dense(1, activation='sigmoid'))

    net = tf.keras.Sequential(stack)
    net.compile(
        optimizer='Adam',
        loss='binary_crossentropy',
        metrics=['binary_crossentropy', 'accuracy'],
    )
    return net
# Global model shared by the strategy below: built once at import time and
# restored from the local checkpoint. Its weights seed `initial_parameters`
# and are overwritten with the aggregated weights in `aggregate_fit`.
model = load_model()
model.load_weights(
    './melamodel/melamodel_weights072.h5',
)
class SaveModelStrategy(fl.server.strategy.FedAvg):
    """FedAvg variant that reports client accuracy and checkpoints each round.

    On every ``aggregate_fit`` call it prints the example-weighted mean of the
    clients' reported ``"accuracy"`` metric, delegates parameter aggregation
    to ``FedAvg``, loads the aggregated weights into the module-level
    ``model`` and saves them under ``./workspace/clientResults/``.
    """

    @staticmethod
    def _weighted_accuracy(results):
        """Return the num_examples-weighted mean of the clients' "accuracy".

        Clients whose ``metrics`` are missing or lack an ``"accuracy"`` entry
        are skipped. Returns None when no client reports accuracy or the
        reporting clients used zero examples in total.
        """
        reported = [
            (res.metrics["accuracy"], res.num_examples)
            for _, res in results
            if res.metrics and "accuracy" in res.metrics
        ]
        total_examples = sum(n for _, n in reported)
        if total_examples == 0:
            return None
        return sum(acc * n for acc, n in reported) / total_examples

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate client updates with FedAvg, then save the result.

        Args:
            server_round: Current federated round number (used in logs and
                in the checkpoint file name).
            results: ``(client, FitRes)`` pairs from successful clients.
            failures: Failed client results or raised exceptions.

        Returns:
            The ``(parameters, metrics)`` tuple produced by
            ``FedAvg.aggregate_fit``; ``(None, {})`` when there are no results.
        """
        if not results:
            # Same (parameters, metrics) shape that is unpacked from super() below.
            return None, {}

        accuracy_aggregated = self._weighted_accuracy(results)
        if accuracy_aggregated is None:
            print(f"Round {server_round}: no client reported an accuracy metric")
            accuracy_agg2 = None
        else:
            print(f"Round {server_round} accuracy aggregated from client results: {accuracy_aggregated}")
            # Only 2 decimal places for the summary message.
            accuracy_agg2 = round(accuracy_aggregated, 2)

        aggregated_weights = super().aggregate_fit(server_round, results, failures)
        aggregated_params, _ = aggregated_weights
        if aggregated_params is not None:
            aggregated_weights_h: List[np.ndarray] = fl.common.parameters_to_weights(aggregated_params)
            model.set_weights(aggregated_weights_h)
            print(f'Federated Learning session completed! The accuracy of the aggregated model is {accuracy_agg2}')
            print(f"Saving round {server_round} model weights...")
            date = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
            save_dir = "./workspace/clientResults"
            # save_weights does not create missing parent directories.
            os.makedirs(save_dir, exist_ok=True)
            model.save_weights(f"{save_dir}/round-{server_round}-weights-{date}.h5")
        return aggregated_weights
# Seed the strategy with the weights of the locally restored global model.
# (fraction_fit=0.01 was previously tried as an extra SaveModelStrategy option.)
initial_params = fl.common.weights_to_parameters(model.get_weights())
strategy = SaveModelStrategy(initial_parameters=initial_params)

# Start the Flower server for a single federated round.
fl.server.start_server(
    strategy=strategy,
    config={"num_rounds": 1},
)