-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTestCase.py
102 lines (95 loc) · 3.04 KB
/
TestCase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
class TestCase:
    """
    Represents a test case for a machine learning model.

    A test case pairs one model/framework with a group of clients; the i-th
    entry of ``clients_names`` is the client whose data is at the i-th entry
    of ``data_paths``.

    Args:
        model_name (str): The name of the machine learning model.
        framework (str): The framework used for the model.
        results_path (str): The path to store the results of the test case.
        clients_names (list): A list of client names.
        data_paths (list): A list of data paths, one per client, in the same
            order as ``clients_names``.

    Raises:
        ValueError: If ``clients_names`` and ``data_paths`` differ in length.
    """

    def __init__(self, model_name, framework, results_path, clients_names, data_paths):
        # Clients and data files are paired positionally (see
        # get_clients_configs, which zips them), so a length mismatch would
        # silently drop clients; reject it up front instead.
        if len(clients_names) != len(data_paths):
            raise ValueError(
                f"clients_names has {len(clients_names)} entries but "
                f"data_paths has {len(data_paths)}; they must match one-to-one"
            )
        self.model_name = model_name
        self.framework = framework
        self.results_path = results_path
        self.clients_names = clients_names
        self.data_paths = data_paths

    def __repr__(self):
        return (
            f"{type(self).__name__}("
            f"model_name={self.model_name!r}, framework={self.framework!r}, "
            f"results_path={self.results_path!r}, "
            f"clients_names={self.clients_names!r}, data_paths={self.data_paths!r})"
        )
# Client groups shared by the experiments below.
TRANSPORT_MODES = ['NYC-bike', 'NYC-taxi', 'CHI-taxi']
TRANSPORT_OPERATORS = ['LyonPT', 'Orange']


def _transport_modes_case(model_name):
    """Build a PyTorch TestCase over the TRANSPORT_MODES clients.

    Each mode's data is read from ``DATA/TransportModes/<mode>/tripdata_full.csv``
    and results go to ``RESULTS/TransportModes``.
    """
    return TestCase(
        model_name,
        'PyTorch',
        'RESULTS/TransportModes',
        TRANSPORT_MODES,
        [f'DATA/TransportModes/{mode}/tripdata_full.csv' for mode in TRANSPORT_MODES],
    )


def _od_data_case(model_name, year):
    """Build a PyTorch TestCase over the TRANSPORT_OPERATORS clients.

    Each operator's data is read from ``DATA/OD_Data/<operator>_<year>.csv``
    and results go to ``RESULTS/OD_Data``.
    """
    return TestCase(
        model_name,
        'PyTorch',
        'RESULTS/OD_Data',
        TRANSPORT_OPERATORS,
        [f'DATA/OD_Data/{client}_{year}.csv' for client in TRANSPORT_OPERATORS],
    )


# The index comment on each entry is its position in TEST_CASES.
TEST_CASES = [
    _transport_modes_case('Fed-LSTM-DSTGCRN'),  # 0
    _transport_modes_case('FedLSTM'),           # 1
    _transport_modes_case('FedGRU'),            # 2
    _transport_modes_case('FedAGCRN'),          # 3
    # NOTE(review): this case reads the 2022 OD files while the other
    # OD_Data cases read 2021 -- confirm this difference is intentional.
    _od_data_case('Fed-LSTM-DSTGCRN', 2022),    # 4
    _od_data_case('FedLSTM', 2021),             # 5
    _od_data_case('FedGRU', 2021),              # 6
    _od_data_case('FedAGCRN', 2021),            # 7
]
def get_clients_configs(test_case):
    """
    Generates a list of client configurations based on the given test case.

    Each configuration dict carries the test case's shared settings (model
    name, framework, results path) plus one client's name and data path.
    Clients are paired with data paths by position.

    Args:
        test_case (TestCase): The test case object.

    Returns:
        list: A list of client configuration dicts, one per client, in the
        order of ``test_case.clients_names``.

    Raises:
        ValueError: If ``test_case.clients_names`` and ``test_case.data_paths``
            differ in length.
    """
    names = test_case.clients_names
    paths = test_case.data_paths
    # zip() would silently truncate to the shorter list and drop clients,
    # so make a mismatch an explicit error instead.
    if len(names) != len(paths):
        raise ValueError(
            f"test case has {len(names)} client names but {len(paths)} data "
            f"paths; they must match one-to-one"
        )
    return [
        {
            'model_name': test_case.model_name,
            'framework': test_case.framework,
            'results_path': test_case.results_path,
            'client_name': client_name,
            'data_path': data_path,
        }
        for client_name, data_path in zip(names, paths)
    ]