Skip to content

Commit 45aabd1

Browse files
committed
Update
0 parents  commit 45aabd1

29 files changed

+16730
-0
lines changed

.gitignore

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
**/__pycache__/
2+
**/LOGS/
3+
**/RESULTS/
4+
**/FIGURES/
5+
**/DataPrepare/
6+
.vscode
7+
.DS_Store
8+
dataloader_test.py
9+
local_test.py

DATA/TransportModes/CHI-taxi/tripdata_full.csv

+2,185
Large diffs are not rendered by default.

DATA/TransportModes/CHI-taxi/weatherdata_full.csv

+2,185
Large diffs are not rendered by default.

DATA/TransportModes/NYC-bike/tripdata_full.csv

+2,185
Large diffs are not rendered by default.

DATA/TransportModes/NYC-bike/weatherdata_full.csv

+2,185
Large diffs are not rendered by default.

DATA/TransportModes/NYC-taxi/tripdata_full.csv

+2,185
Large diffs are not rendered by default.

DATA/TransportModes/NYC-taxi/weatherdata_full.csv

+2,185
Large diffs are not rendered by default.

Experiments.py

+145
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""
2+
This script starts a server and multiple client processes for federated learning.
3+
4+
The server process runs a Federated Learning server that orchestrates the federated learning process.
5+
The client processes simulate multiple clients participating in the federated learning process.
6+
7+
---
8+
Reference: "Federated Dynamic Model For Spatiotemporal Data Forecasting In Transportation", submitted to IEEE Transactions on Intelligent Transportation Systems, Jan 2025
9+
Github: https://github.com/nhat-thien/Fed-LSTM-DSTGCRN
10+
"""
11+
12+
from datetime import datetime
13+
import subprocess
14+
import time
15+
import json
16+
import os
17+
from FL_HELPERS.FL_constants import GENERAL_INFO
18+
from Hyperparameters import get_hyperparameters
19+
from TestCase import get_clients_configs, TEST_CASES
20+
21+
# Suppress TensorFlow warnings
22+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
23+
24+
# Huge faster than INTEL MKL
25+
os.environ['MKL_THREADING_LAYER'] = 'GNU'
26+
27+
#--------------------------------------------------------------------
28+
# Because we known beforehand the number of GPUs in the machine.
29+
# You can set it to [0] if you have only one GPU, then make change
30+
# accordingly where GPU_IDs is used.
31+
#--------------------------------------------------------------------
32+
GPU_IDs = [0, 1, 2, 3]
33+
34+
35+
def main(is_FL, TEST_CASE_ID=1, PARAM_ID=0):
    """
    Run one experiment: a server subprocess plus one subprocess per client.

    Starts the FL server, launches every client of the chosen test case,
    waits for the server to terminate, then appends a timing record to the
    server's JSON training log.

    Parameters
    ----------
    is_FL : bool
        True for federated learning, False for centralized training.
    TEST_CASE_ID : int
        Index into TEST_CASES (see TestCase.py).
    PARAM_ID : int
        Index into the hyperparameter combinations returned by
        get_hyperparameters().
    """
    #----------------------------------------------------------------
    # START THE SERVER
    #----------------------------------------------------------------
    server_start_time = time.time()
    num_clients = len(TEST_CASES[TEST_CASE_ID].clients_names)

    # Expose every configured GPU to all subprocesses once, instead of
    # re-assigning the same hard-coded string before each Popen call.
    # Derived from GPU_IDs so the constant and this setting stay in sync.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in GPU_IDs)

    python_command = (
        f"from FL_HELPERS.FL_subprocess import run_server;"
        f"from TestCase import get_clients_configs, TEST_CASES;"
        f"from Hyperparameters import get_hyperparameters;"
        f"clients_configs = get_clients_configs(TEST_CASES[{TEST_CASE_ID}]);"
        f"params = get_hyperparameters(clients_configs[0].get('model_name'), {is_FL});"
        f"run_server(params[{PARAM_ID}], {num_clients}, is_FL={is_FL})"
    )
    server_process = subprocess.Popen(["python", "-c", python_command])
    time.sleep(5)  # give the server time to start before clients connect
    print(f"{'-'*90}\n{GENERAL_INFO} SERVER STARTED\n{'-'*90}")

    #----------------------------------------------------------------
    # RUN SUBPROCESSES FOR CLIENTS
    #----------------------------------------------------------------
    client_processes = []
    clients_configs = get_clients_configs(TEST_CASES[TEST_CASE_ID])

    for i, client_config in enumerate(clients_configs):

        # Round-robin over the available GPUs so that a test case with
        # more clients than GPUs no longer raises an IndexError.
        gpu_id = GPU_IDs[i % len(GPU_IDs)]

        python_command = (
            f"from FL_HELPERS.FL_subprocess import run_client;"
            f"from TestCase import get_clients_configs, TEST_CASES;"
            f"from Hyperparameters import get_hyperparameters;"
            f"clients_configs = get_clients_configs(TEST_CASES[{TEST_CASE_ID}]);"
            f"params = get_hyperparameters(clients_configs[{i}].get('model_name'), {is_FL});"
            f"params[{PARAM_ID}].device = 'cuda:{gpu_id}';"
            f"run_client(params[{PARAM_ID}], clients_configs[{i}], is_FL={is_FL}, device='cuda:{gpu_id}')"
        )
        client_process = subprocess.Popen(["python", "-c", python_command])
        print(f"{GENERAL_INFO} {client_config['client_name']}: Client machine started")
        client_processes.append(client_process)
        time.sleep(1)  # stagger client start-ups

    # Wait for the server process to finish
    server_process.wait()

    # Measure the time when the server stops
    server_stop_time_seconds = time.time() - server_start_time
    print(f"{'='*90}\nServer stopped at: {server_stop_time_seconds:0.2f} seconds\n{'='*90}\n\n")

    # Load the existing server log; start a fresh list on the first run.
    server_log_path = os.path.join(os.getcwd(), f"{TEST_CASES[TEST_CASE_ID].results_path}/SERVER_training_logs.json")
    try:
        with open(server_log_path, "r") as file:
            server_logs = json.load(file)
    except FileNotFoundError:
        server_logs = []

    # Add a record for the server start and stop times
    record = {
        "started_at": datetime.fromtimestamp(server_start_time).isoformat(),
        "stopped_at": datetime.now().isoformat(),
        "duration": server_stop_time_seconds,
        "clients": TEST_CASES[TEST_CASE_ID].clients_names,
        "model_name": TEST_CASES[TEST_CASE_ID].model_name,
        "NOTE": f'Testcase {TEST_CASE_ID}: FL for all clients FedAvg, Attentive, AttentiveCSV, CSV'
    }
    server_logs.append(record)

    # Save the updated data to the JSON file
    os.makedirs(os.path.dirname(server_log_path), exist_ok=True)
    with open(server_log_path, "w") as file:
        json.dump(server_logs, file)
115+
116+
117+
118+
119+
120+
if __name__ == "__main__":
121+
122+
123+
for is_FL in [False]:
124+
125+
#--------------------------------------------------------------------
126+
# Test case ID, see TestCase.py
127+
#--------------------------------------------------------------------
128+
for TEST_CASE_ID in [0]:
129+
130+
131+
#----------------------------------------------------------------
132+
# Get hyperparameters for the corresponding test case
133+
#----------------------------------------------------------------
134+
clients_configs = get_clients_configs(TEST_CASES[TEST_CASE_ID])
135+
model_name = clients_configs[0].get('model_name')
136+
parmas = get_hyperparameters(model_name, is_FL)
137+
print(f"{GENERAL_INFO} There are {len(parmas)} combinations of hyperparameters")
138+
139+
140+
#----------------------------------------------------------------
141+
# MAIN LOOP over all hyperparameters configurations
142+
#----------------------------------------------------------------
143+
for i in range(len(parmas)):
144+
print(f"{GENERAL_INFO} TEST CASE {TEST_CASE_ID}: {model_name} -- {'CENTRALIZED LEARNING' if is_FL == False else f'FL SCHEME: {parmas[i].FL_scheme}'} -- CSV: {parmas[i].use_CSV}")
145+
main(is_FL, TEST_CASE_ID=TEST_CASE_ID, PARAM_ID=i)

0 commit comments

Comments
 (0)