
Commit 5f8c9fd

"First commit to publish source code of FADNet project"
1 parent 42419db commit 5f8c9fd

29 files changed: +2930 −5 lines changed

.gitignore

+19
@@ -0,0 +1,19 @@
# PyCharm and GitHub
.idea/


# Environment
env/

# Dataset
data/

# Graph results
graph_utils/results/

# Logs, models and pretrained models
loggs/
pretrained_models/

# Json and plots
results/

README.md

+91-5
@@ -1,7 +1,93 @@
-# FADNet: Deep Federated Learning for Autonomous Driving
-## Abstract
-## Code & Data
-Coming soon

# Deep Federated Learning for Autonomous Driving

*Autonomous driving is an active research topic in both academia and industry. However, most of the existing solutions focus on improving the accuracy by training learnable models with centralized large-scale data. Therefore, these methods do not take into account the user's privacy. In this paper, we present a new approach to learn autonomous driving policy while respecting privacy concerns. We propose a peer-to-peer Deep Federated Learning (DFL) approach to train deep architectures in a fully decentralized manner and remove the need for central orchestration. We design a new Federated Autonomous Driving network (FADNet) that can improve the model stability, ensure convergence, and handle imbalanced data distribution problems while being trained with federated learning methods. Extensive experimental results on three datasets show that our approach with FADNet and DFL achieves superior accuracy compared with other recent methods. Furthermore, our approach can maintain privacy by not collecting user data to a central server.*

![Fig-1](misc/FADNet.png)
*<center>**Figure 1**: The architecture of our Federated Autonomous Driving Net (FADNet).</center>*

This repository is the implementation of a decentralized federated learning approach for Autonomous Driving. We benchmark our method on three public datasets: [Udacity](), [Carla](), and [Gazebo]().

For details, please refer to our [paper](https://arxiv.org/abs/2110.05754).

This repository is based on and inspired by @Othmane Marfoq's [work](https://github.com/omarfoq/communication-in-cross-silo-fl). We sincerely thank them for sharing their code.

## Summary

* [Prerequisites](#prerequisites)
* [Datasets](#datasets)
* [Federated Learning for Autonomous Driving](#federated-learning-for-autonomous-driving)
* [Training](#training)
* [Pretrained models and Testing](#pretrained-models-and-testing)
* [Citation](#citation)
* [License](#license)
* [More information](#more-information)

### Prerequisites

Python 3.6

CUDA 9.2

Please install the required packages by running the following command:
```
pip install -r requirements.txt
```

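After installing, a quick way to confirm the environment matches the prerequisites above is the small check below. This is only a convenience sketch; it assumes PyTorch is among the installed requirements, as `communication.py` in this commit imports `torch`.

```python
import sys
import torch

# Prints the interpreter and PyTorch/CUDA versions so they can be compared
# against the prerequisites above (Python 3.6, CUDA 9.2).
print("python:", sys.version.split()[0])
print("torch :", torch.__version__)
print("cuda  :", torch.version.cuda, "| available:", torch.cuda.is_available())
```
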
### Datasets

* For the GAZEBO dataset, we provide:
    * The original dataset and the split train/test dataset for the GAIA network at [link](). You can download and extract them into the "data/driving_gazebo/" folder.

* For the CARLA dataset, we provide:
    * The original dataset and the split train/test dataset for the GAIA network at [link](). You can download and extract them into the "data/driving_carla/" folder.

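For reference, `communication.py` (added in this commit) looks for the data under `data/<experiment>/<network_name>/`, with aggregate train/test files plus one file per silo inside the train folder. Below is a rough sketch of the expected layout for GAZEBO with the GAIA network; the exact file extension comes from `EXTENSIONS` in `utils/utils.py` (not part of this excerpt), and the `gaia` folder name is an assumption.

```
data/driving_gazebo/
    gaia/
        train/
            train<ext>    # aggregate training set
            0<ext>        # data of silo/worker 0
            1<ext>        # data of silo/worker 1, and so on
        test/
            test<ext>
```
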
### Federated Learning for Autonomous Driving

Important: Before running any commands in this section, please run the following command to enter the 'graph_utils' folder:
```
cd graph_utils
```
You are now in the 'graph_utils' folder.
* To generate the networks for the GAZEBO dataset and compute their cycle time:
```
bash generate_network_driving-gazebo.sh
```

* To generate the networks for the CARLA dataset and compute their cycle time:
```
bash generate_network_driving-carla.sh
```

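The generated topology is what `communication.py` loads through `get_network(...)`: a weighted graph whose edge weights are used as mixing coefficients, so that in every communication round each silo averages its neighbours' parameters. The snippet below is a minimal, self-contained sketch of that mixing rule on a toy 3-silo ring; it assumes a `networkx`-style graph, which matches how the graph is queried in `Peer2PeerNetwork.mix`, and the numbers are illustrative.

```python
import networkx as nx

# Toy 3-silo ring with uniform mixing weights (illustrative values only).
g = nx.Graph()
g.add_weighted_edges_from([(0, 1, 0.5), (1, 2, 0.5), (2, 0, 0.5)])

# One scalar "parameter" per silo, standing in for a full parameter tensor.
params = {0: 1.0, 1: 2.0, 2: 4.0}

# Gossip mixing: each silo replaces its value with the weighted sum over its
# neighbours, mirroring the neighbour loop in Peer2PeerNetwork.mix().
mixed = {i: sum(g[i][j]["weight"] * params[j] for j in g.neighbors(i)) for i in g.nodes}
print(mixed)  # {0: 3.0, 1: 2.5, 2: 1.5}
```
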
### Training

* To train our method on the GAZEBO dataset with the GAIA network, run:

```
bash train_gazebo_gaia.sh
```

* To train our method on the CARLA dataset with the GAIA network, you can use the same setup as for GAZEBO.

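The training scripts themselves are not shown in this commit excerpt, but `communication.py` fixes the shape of the loop they drive: build a `Peer2PeerNetwork` from the parsed arguments and call `mix()` once per communication round. The sketch below is only an illustration of that loop; the argument values and the real CLI flags used by `train_gazebo_gaia.sh` may differ.

```python
from argparse import Namespace

from communication import Peer2PeerNetwork

# Illustrative arguments: the field names mirror what Network.__init__ reads,
# but the values here are placeholders, not the project's actual defaults.
args = Namespace(
    experiment="driving_gazebo", network_name="gaia", architecture="fadnet",
    device="cuda", bz_train=32, bz_test=32, local_steps=1, log_freq=10,
    fit_by_epoch=False, lr=1e-3, optimizer="adam", decay="constant",
    save_logg_path="", test=False,
)

network = Peer2PeerNetwork(args)
for _ in range(100):   # number of communication rounds (placeholder)
    network.mix()      # local updates on every silo, then gossip averaging
```
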
### Pretrained models and Testing

We provide the pretrained models produced by our method at the last epoch of training on the GAZEBO dataset with the GAIA network. Please download them at [link]() and extract them into the "pretrained_models/DRIVING-GAZEBO_GAIA" folder.

The models can be evaluated on the GAZEBO train and test sets via:
```
bash test_gazebo_gaia.sh
```

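The checkpoint files follow the format written by `Network.save_models` in `communication.py`: each saved round contains a `model_global.pth` plus one `model_silo_<i>.pth` per silo, and every file is a plain PyTorch dictionary holding the round index and a `model_state` state dict. Below is a small sketch for inspecting one of the downloaded files; the path is illustrative, so point it at whichever `.pth` file the archive contains.

```python
import torch

# Illustrative path; adjust it to the extracted checkpoint you want to inspect.
ckpt = torch.load("pretrained_models/DRIVING-GAZEBO_GAIA/model_global.pth", map_location="cpu")
print(ckpt.keys())                 # e.g. dict_keys(['round', 'model_state'])
state_dict = ckpt["model_state"]
print(sum(p.numel() for p in state_dict.values()), "parameters")
# The state dict can then be loaded into a model built with the repo's get_model().
```
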
### Citation

If you use this code as part of any published research, we'd really appreciate it if you could cite the following paper:

```
Updating
```

### License

MIT License

### More information

AIOZ AI Homepage: https://ai.aioz.io

communication.py

+202
@@ -0,0 +1,202 @@
import os
from abc import ABC, abstractmethod

import torch
from torch.utils.tensorboard import SummaryWriter

from utils.utils import get_network, get_iterator, get_model, args_to_string, EXTENSIONS, logger_write_params, print_model
import time


class Network(ABC):
    def __init__(self, args):
        """
        Abstract class representing a network of workers collaborating to train a machine learning model;
        each worker has a local model and a local data iterator.
        Sub-classes should implement `mix` to specify how the communication is done.
        :param args: parameters defining the network
        """
        self.args = args
        self.device = args.device
        self.batch_size_train = args.bz_train
        self.batch_size_test = args.bz_test
        self.network = get_network(args.network_name, args.architecture, args.experiment)
        self.n_workers = self.network.number_of_nodes()
        self.local_steps = args.local_steps
        self.log_freq = args.log_freq
        self.fit_by_epoch = args.fit_by_epoch
        self.initial_lr = args.lr
        self.optimizer_name = args.optimizer
        self.lr_scheduler_name = args.decay

        # create logger
        if args.save_logg_path == "":
            self.logger_path = os.path.join("loggs", args_to_string(args), args.architecture)
        else:
            self.logger_path = args.save_logg_path
        os.makedirs(self.logger_path, exist_ok=True)
        if not args.test:
            self.logger_write_param = logger_write_params(os.path.join(self.logger_path, 'log.txt'))
        else:
            self.logger_write_param = logger_write_params(os.path.join(self.logger_path, 'test.txt'))
        self.logger_write_param.write(args.__repr__())

        self.logger_write_param.write('>>>>>>>>>> start time: ' + str(time.asctime()))
        self.time_start = time.time()
        self.time_start_update = self.time_start

        self.logger = SummaryWriter(self.logger_path)

        self.round_idx = 0  # index of the current communication round

        # get data loaders
        self.train_dir = os.path.join("data", args.experiment, args.network_name, "train")
        self.test_dir = os.path.join("data", args.experiment, args.network_name, "test")

        extension = EXTENSIONS["driving"] if "driving" in args.experiment else EXTENSIONS[args.experiment]
        self.train_path = os.path.join(self.train_dir, "train" + extension)
        self.test_path = os.path.join(self.test_dir, "test" + extension)

        print('- Loading: > %s < dataset from: %s' % (args.experiment, self.train_path))
        self.train_iterator = get_iterator(args.experiment, self.train_path, self.device, self.batch_size_test)
        print('- Loading: > %s < dataset from: %s' % (args.experiment, self.test_path))
        self.test_iterator = get_iterator(args.experiment, self.test_path, self.device, self.batch_size_test)

        self.workers_iterators = []
        train_data_size = 0
        print('>>>>>>>>>> Loading worker-datasets')
        for worker_id in range(self.n_workers):
            data_path = os.path.join(self.train_dir, str(worker_id) + extension)
            print('\t + Loading: > %s < dataset from: %s' % (args.experiment, data_path))
            self.workers_iterators.append(get_iterator(args.experiment, data_path, self.device, self.batch_size_train))
            train_data_size += len(self.workers_iterators[-1])

        self.epoch_size = int(train_data_size / self.n_workers)

        # create workers models
        self.workers_models = [get_model(args.experiment, self.device,
                                         optimizer_name=self.optimizer_name, lr_scheduler=self.lr_scheduler_name,
                                         initial_lr=self.initial_lr, epoch_size=self.epoch_size)
                               for w_i in range(self.n_workers)]

        # average model of all workers
        self.global_model = get_model(args.experiment,
                                      self.device,
                                      epoch_size=self.epoch_size)
        print_model(self.global_model.net, self.logger_write_param)

        # write initial performance
        if not self.args.test:
            self.write_logs()

    @abstractmethod
    def mix(self):
        pass

    def write_logs(self):
        """
        Write train/test loss and train/test RMSE for the average model,
        log the intra-worker parameter variance (consensus), and save the average model.
        """
        if (self.round_idx - 1) == 0:
            return None
        print('>>>>>>>>>> Evaluating')
        print('\t - train set')
        start_time = time.time()
        train_loss, train_rmse = self.global_model.evaluate_iterator(self.train_iterator)
        end_time_train = time.time()
        print('\t - test set')
        test_loss, test_rmse = self.global_model.evaluate_iterator(self.test_iterator)
        end_time_test = time.time()
        self.logger.add_scalar("Train/Loss", train_loss, self.round_idx)
        self.logger.add_scalar("Train/RMSE", train_rmse, self.round_idx)
        self.logger.add_scalar("Test/Loss", test_loss, self.round_idx)
        self.logger.add_scalar("Test/RMSE", test_rmse, self.round_idx)
        self.logger.add_scalar("Train/Time", end_time_train - start_time, self.round_idx)
        self.logger.add_scalar("Test/Time", end_time_test - end_time_train, self.round_idx)

        # write parameter variance
        average_parameter = self.global_model.get_param_tensor()

        param_tensors_by_workers = torch.zeros((average_parameter.shape[0], self.n_workers))

        for ii, model in enumerate(self.workers_models):
            param_tensors_by_workers[:, ii] = model.get_param_tensor() - average_parameter

        consensus = (param_tensors_by_workers ** 2).mean()
        self.logger.add_scalar("Consensus", consensus, self.round_idx)
        self.logger_write_param.write(f'\t Round: {self.round_idx} |Train Loss: {train_loss:.5f} |Train RMSE: {train_rmse:.5f} |Eval-train Time: {end_time_train - start_time:.3f}')
        self.logger_write_param.write(f'\t -----: {self.round_idx} |Test Loss: {test_loss:.5f} |Test RMSE: {test_rmse:.5f} |Eval-test Time: {end_time_test - end_time_train:.3f}')
        self.logger_write_param.write(f'\t -----: Time: {time.time() - self.time_start_update:.3f}')
        self.logger_write_param.write(f'\t -----: Total Time: {time.time() - self.time_start:.3f}')

        self.time_start_update = time.time()
        if not self.args.test:
            self.save_models(round=self.round_idx)

    def save_models(self, round):
        round_path = os.path.join(self.logger_path, 'round_%s' % round)
        os.makedirs(round_path, exist_ok=True)
        path_global = round_path + '/model_global.pth'
        model_dict = {
            'round': round,
            'model_state': self.global_model.net.state_dict()
        }
        torch.save(model_dict, path_global)
        for i in range(self.n_workers):
            path_silo = round_path + '/model_silo_%s.pth' % i
            model_dict = {
                'epoch': round,
                'model_state': self.workers_models[i].net.state_dict()
            }
            torch.save(model_dict, path_silo)

    def load_models(self, round):
        self.round_idx = round
        round_path = os.path.join(self.logger_path, 'round_%s' % round)
        path_global = round_path + '/model_global.pth'
        print('loading %s' % path_global)
        model_data = torch.load(path_global)
        self.global_model.net.load_state_dict(model_data.get('model_state', model_data))
        for i in range(self.n_workers):
            path_silo = round_path + '/model_silo_%s.pth' % i
            print('loading %s' % path_silo)
            model_data = torch.load(path_silo)
            self.workers_models[i].net.load_state_dict(model_data.get('model_state', model_data))


class Peer2PeerNetwork(Network):
    def mix(self, write_results=True):
        """
        Mix local model parameters in a gossip fashion.
        :param write_results: if True, build the average model, evaluate it and write logs every `log_freq` rounds
        """
        # update workers: each silo performs local training steps on its own data
        for worker_id, model in enumerate(self.workers_models):
            model.net.to(self.device)
            if self.fit_by_epoch:
                model.fit_iterator(train_iterator=self.workers_iterators[worker_id],
                                   n_epochs=self.local_steps, verbose=0)
            else:
                model.fit_batches(iterator=self.workers_iterators[worker_id], n_steps=self.local_steps)

        # write logs: build the average (global) model and evaluate it every `log_freq` rounds
        if ((self.round_idx - 1) % self.log_freq == 0) and write_results:
            for param_idx, param in enumerate(self.global_model.net.parameters()):
                param.data.fill_(0.)
                for worker_model in self.workers_models:
                    param.data += (1 / self.n_workers) * list(worker_model.net.parameters())[param_idx].data.clone()
            self.write_logs()

        # mix models: each worker replaces its parameters by the weighted average of its neighbours' parameters
        for param_idx, param in enumerate(self.global_model.net.parameters()):
            temp_workers_param_list = [torch.zeros(param.shape).to(self.device) for _ in range(self.n_workers)]
            for worker_id, model in enumerate(self.workers_models):
                for neighbour in self.network.neighbors(worker_id):
                    coeff = self.network.get_edge_data(worker_id, neighbour)["weight"]
                    temp_workers_param_list[worker_id] += \
                        coeff * list(self.workers_models[neighbour].net.parameters())[param_idx].data.clone()

            for worker_id, model in enumerate(self.workers_models):
                for param_idx_, param_ in enumerate(model.net.parameters()):
                    if param_idx_ == param_idx:
                        param_.data = temp_workers_param_list[worker_id].clone()

        self.round_idx += 1
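
A note on the "Consensus" scalar logged by `write_logs` above: it is the mean squared deviation of each worker's flattened parameter vector from the average model, so it shrinks toward zero as the silos converge to the same weights. The following is a tiny self-contained illustration of the same computation, using toy tensors rather than the project's models.

```python
import torch

# Columns are flattened parameter vectors, one per worker (toy values).
workers = torch.tensor([[1.0, 1.2, 0.9],
                        [2.0, 2.1, 1.9]])       # shape: (n_params, n_workers)
average = workers.mean(dim=1, keepdim=True)     # the "global" (average) model

# Same quantity as in write_logs(): mean squared deviation from the average.
consensus = ((workers - average) ** 2).mean()
print(round(consensus.item(), 4))  # 0.0111 for these toy values
```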

graph_utils/generate_network_driving-carla.sh

@@ -0,0 +1,6 @@
echo "################"
echo "gaia"
python generate_networks.py gaia --experiment driving_carla --upload_capacity 1e10 --download_capacity 1e10
echo "################"
echo "amazon_us"
python generate_networks.py amazon_us --experiment driving_carla --upload_capacity 1e10 --download_capacity 1e10

graph_utils/generate_network_driving-gazebo.sh

@@ -0,0 +1,6 @@
echo "################"
echo "gaia"
python generate_networks.py gaia --experiment driving_gazebo --upload_capacity 1e10 --download_capacity 1e10
echo "################"
echo "amazon_us"
python generate_networks.py amazon_us --experiment driving_gazebo --upload_capacity 1e10 --download_capacity 1e10
