Skip to content

Commit

Permalink
update cifar10 and gnn examples (#2340)
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth authored Feb 1, 2024
1 parent 029815e commit edaf8d2
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

# 4.1 Central vs. FedAvg
experiments = {
"cifar10_central": {"tag": "val_acc_local_model"},
"cifar10_central": {"tag": "val_acc_local_model", "alpha": 0.0},
"cifar10_fedavg": {"tag": "val_acc_global_model", "alpha": 1.0},
}

Expand Down Expand Up @@ -95,6 +95,8 @@ def main():
alpha = exp.get("alpha", None)
if alpha:
config_name = config_name + f"*alpha{alpha}"
else:
raise ValueError(f"Expected an alpha value to be provided but got alpha={alpha}")
eventfile = glob.glob(
os.path.join(client_results_root, config_name, "**", "app_site-1", "events.*"), recursive=True
)
Expand All @@ -116,7 +118,8 @@ def main():
try:
xsite_data[k].append(xsite_results["site-1"][k]["val_accuracy"])
except Exception as e:
raise ValueError(f"No val_accuracy for {k} in {xsite_file}!")
xsite_data[k].append(None)
print(f"Warning: No val_accuracy for {k} in {xsite_file}!")

print("Training TB data:")
print(pd.DataFrame(data))
Expand Down
7 changes: 1 addition & 6 deletions examples/advanced/cifar10/cifar10-sim/run_simulator.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,7 @@ n_clients=$4

# specify output workdir
RESULT_ROOT=/tmp/nvflare/sim_cifar10
if [ 1 -eq "$(echo "${alpha} > 0" | bc)" ]
then
out_workspace=${RESULT_ROOT}/${job}_alpha${alpha}
else
out_workspace=${RESULT_ROOT}/${job}
fi
out_workspace=${RESULT_ROOT}/${job}_alpha${alpha}

# run FL simulator
./set_alpha.sh "${job}" "${alpha}"
Expand Down
12 changes: 6 additions & 6 deletions examples/advanced/gnn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ python3 -m pip install -r requirements.txt
```
To support the PyTorch Geometric functionality necessary for this example, we need extra dependencies. Please refer to the [installation guide](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) and install accordingly:
```
pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
python3 -m pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
```

#### Job Template
Expand All @@ -46,8 +46,8 @@ nvflare job list_templates
We can see the "sag_gnn" template is available

#### Protein Classification
The PPI dataset is directly available via the torch_geometric library; we randomly split the dataset into 2 subsets, one for each client.
First, we run the local training on each client, as well as the whole dataset.
The PPI dataset is directly available via the torch_geometric library; we randomly split the dataset into 2 subsets, one for each client (`--client_id 1` and `--client_id 2`).
First, we run the local training on each client, as well as on the whole dataset with `--client_id 0`.
```
python3 code/graphsage_protein_local.py --client_id 0
python3 code/graphsage_protein_local.py --client_id 1
Expand All @@ -64,7 +64,7 @@ For client configs, we set client_ids for each client, and the number of local e

For server configs, we set the number of rounds for federated training, the key metric for model selection, and the model class path with model hyperparameters.

With the produced job, we run the federated training on both clients via FedAvg using NVFlare Simulator.
With the produced job, we run the federated training on both clients via FedAvg using the NVFlare Simulator.
```
nvflare simulator -w /tmp/nvflare/gnn/protein_fl_workspace -n 2 -t 2 /tmp/nvflare/jobs/gnn_protein
```
Expand All @@ -74,7 +74,7 @@ We first download the Elliptic++ dataset to `/tmp/nvflare/datasets/elliptic_pp`
- `txs_classes.csv`: transaction id and its class (licit or illicit)
- `txs_edgelist.csv`: connections for transaction ids
- `txs_features.csv`: transaction id and its features
Then, we run the local training on each client, as well as the whole dataset.
Then, we run the local training on each client, as well as on the whole dataset. Again, `--client_id 0` uses all data.
```
python3 code/graphsage_finance_local.py --client_id 0
python3 code/graphsage_finance_local.py --client_id 1
Expand All @@ -87,7 +87,7 @@ nvflare job create -force -j "/tmp/nvflare/jobs/gnn_finance" -w "sag_gnn" -sd "c
-f app_2/config_fed_client.conf app_script="graphsage_finance_fl.py" app_config="--client_id 2 --epochs 10" \
-f app_server/config_fed_server.conf num_rounds=7 key_metric="validation_auc" model_class_path="pyg_sage.SAGE" components[0].args.model.args.in_channels=165 components[0].args.model.args.hidden_channels=256 components[0].args.model.args.num_layers=3 components[0].args.model.args.num_classes=2
```
And with the produced job, we run the federated training on both clients via FedAvg using NVFlare Simulator.
And with the produced job, we run the federated training on both clients via FedAvg using the NVFlare Simulator.
```
nvflare simulator -w /tmp/nvflare/gnn/finance_fl_workspace -n 2 -t 2 /tmp/nvflare/jobs/gnn_finance
```
Expand Down
6 changes: 5 additions & 1 deletion nvflare/app_common/workflows/base_fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from nvflare.apis.fl_constant import FLMetaKey
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.abstract.model import make_model_learnable
from nvflare.app_common.aggregators.weighted_aggregation_helper import WeightedAggregationHelper
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType
Expand Down Expand Up @@ -142,5 +143,8 @@ def update_model(self, aggr_result):

self.model = FLModelUtils.update_model(self.model, aggr_result)

self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self.model, private=True, sticky=True)
# persistor uses Learnable format to save model
ml = make_model_learnable(weights=self.model.params, meta_props=self.model.meta)
self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, ml, private=True, sticky=True)

self.event(AppEventType.AFTER_SHAREABLE_TO_LEARNABLE)
10 changes: 8 additions & 2 deletions nvflare/app_common/workflows/model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ def start_controller(self, fl_ctx: FLContext) -> None:
else:
self.model = FLModel(params_type=ParamsType.FULL, params={})

self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self.model, private=True, sticky=True)
# persistor uses Learnable format to save model
ml = make_model_learnable(weights=self.model.params, meta_props=self.model.meta)
self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, ml, private=True, sticky=True)
self.event(AppEventType.INITIAL_MODEL_LOADED)

self.engine = self.fl_ctx.get_engine()
Expand Down Expand Up @@ -231,7 +233,11 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None:
result = client_task.result
client_name = client_task.client.name

self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self._current_round, private=True, sticky=True)

self.event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT)
self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx)
self.event(AppEventType.AFTER_CONTRIBUTION_ACCEPT)

# Turn result into FLModel
result_model = FLModelUtils.from_shareable(result)
Expand Down Expand Up @@ -270,7 +276,6 @@ def _accept_train_result(self, client_name: str, result: Shareable, fl_ctx: FLCo
)
return

self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self._current_round, private=True, sticky=True)
self.fl_ctx.set_prop(AppConstants.TRAINING_RESULT, result, private=True, sticky=False)

@abstractmethod
Expand Down Expand Up @@ -307,6 +312,7 @@ def save_model(self):
) or self._current_round == self._num_rounds - 1:
self.info("Start persist model on server.")
self.event(AppEventType.BEFORE_LEARNABLE_PERSIST)
# persistor uses Learnable format to save model
ml = make_model_learnable(weights=self.model.params, meta_props=self.model.meta)
self.persistor.save(ml, self.fl_ctx)
self.event(AppEventType.AFTER_LEARNABLE_PERSIST)
Expand Down

0 comments on commit edaf8d2

Please sign in to comment.