Skip to content

Commit

Permalink
Simple FedAvg workflow (#2157)
Browse files Browse the repository at this point in the history
* fedavg model controller

* use broadcast_and_wait; add FLComponentHelper class

* Use FLModel in aggregate_fn and update_model

* support FLModel in persistor

* restructure model controller

* simplify

* rename results

* formatting

* remove debug msg

* add docstrings

* check results

* rm unused file

* add phases and stats collection

* remove temp filenames

* add argument check

* add scaffold

* create ModelController

* restore persistance format manager and convert in ModelController

* Use FLComponentWrapper

* add experimental decorator

* add relay and wait

* update docstrings

* fix decorators; remove relay and wait

* reset job defaults

* reset default alpha

* formatting

* address comments

* address more comments

* restore base controller

* replace meta by default

* address comments
  • Loading branch information
holgerroth authored Dec 8, 2023
1 parent 9de8f31 commit 62f2f1d
Show file tree
Hide file tree
Showing 21 changed files with 1,085 additions and 280 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"components": [
{
"id": "cifar10-learner",
"path": "pt.learners.cifar10_model_learner.CIFAR10ModelLearner",
"path": "pt.learners.CIFAR10ModelLearner",
"args": {
"aggregation_epochs": "{AGGREGATION_EPOCHS}",
"lr": 1e-2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,6 @@
}
}
},
{
"id": "shareable_generator",
"name": "FullModelShareableGenerator",
"args": {}
},
{
"id": "aggregator",
"name": "InTimeAccumulateWeightedAggregator",
"args": {}
},
{
"id": "model_selector",
"name": "IntimeModelSelector",
Expand All @@ -50,19 +40,13 @@
],
"workflows": [
{
"id": "scatter_gather_ctl",
"name": "ScatterAndGather",
"args": {
"min_clients" : "{min_clients}",
"num_rounds" : "{num_rounds}",
"start_round": 0,
"wait_time_after_min_received": 10,
"aggregator_id": "aggregator",
"persistor_id": "persistor",
"shareable_generator_id": "shareable_generator",
"train_task_name": "train",
"train_timeout": 0
}
"id": "fedavg_ctl",
"name": "FedAvg",
"args": {
"min_clients": "{min_clients}",
"num_rounds": "{num_rounds}",
"persistor_id": "persistor"
}
},
{
"id": "cross_site_model_eval",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"components": [
{
"id": "cifar10-learner",
"path": "pt.learners.cifar10_model_learner.CIFAR10ModelLearner",
"path": "pt.learners.CIFAR10ModelLearner",
"args": {
"train_idx_root": "{TRAIN_SPLIT_ROOT}",
"aggregation_epochs": "{AGGREGATION_EPOCHS}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,6 @@
}
}
},
{
"id": "shareable_generator",
"name": "FullModelShareableGenerator",
"args": {}
},
{
"id": "aggregator",
"name": "InTimeAccumulateWeightedAggregator",
"args": {}
},
{
"id": "model_selector",
"name": "IntimeModelSelector",
Expand All @@ -59,18 +49,12 @@
],
"workflows": [
{
"id": "scatter_gather_ctl",
"name": "ScatterAndGather",
"id": "fedavg_ctl",
"name": "FedAvg",
"args": {
"min_clients": "{min_clients}",
"num_rounds": "{num_rounds}",
"start_round": 0,
"wait_time_after_min_received": 10,
"aggregator_id": "aggregator",
"persistor_id": "persistor",
"shareable_generator_id": "shareable_generator",
"train_task_name": "train",
"train_timeout": 0
"persistor_id": "persistor"
}
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"components": [
{
"id": "cifar10-learner",
"path": "pt.learners.cifar10_model_learner.CIFAR10ModelLearner",
"path": "pt.learners.CIFAR10ModelLearner",
"args": {
"train_idx_root": "{TRAIN_SPLIT_ROOT}",
"aggregation_epochs": "{AGGREGATION_EPOCHS}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"components": [
{
"id": "cifar10-learner",
"path": "pt.learners.cifar10_model_learner.CIFAR10ModelLearner",
"path": "pt.learners.CIFAR10ModelLearner",
"args": {
"train_idx_root": "{TRAIN_SPLIT_ROOT}",
"aggregation_epochs": "{AGGREGATION_EPOCHS}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,6 @@
}
}
},
{
"id": "shareable_generator",
"name": "FullModelShareableGenerator",
"args": {}
},
{
"id": "aggregator",
"name": "InTimeAccumulateWeightedAggregator",
"args": {}
},
{
"id": "model_selector",
"name": "IntimeModelSelector",
Expand All @@ -62,18 +52,12 @@
],
"workflows": [
{
"id": "scatter_gather_ctl",
"name": "ScatterAndGather",
"id": "fedavg_ctl",
"name": "FedAvg",
"args": {
"min_clients" : "{min_clients}",
"num_rounds" : "{num_rounds}",
"start_round": 0,
"wait_time_after_min_received": 10,
"aggregator_id": "aggregator",
"persistor_id": "persistor",
"shareable_generator_id": "shareable_generator",
"train_task_name": "train",
"train_timeout": 0
"min_clients": "{min_clients}",
"num_rounds": "{num_rounds}",
"persistor_id": "persistor"
}
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
],
"executor": {
"id": "Executor",
"path": "nvflare.app_common.executors.learner_executor.LearnerExecutor",
"path": "nvflare.app_common.executors.model_learner_executor.ModelLearnerExecutor",
"args": {
"learner_id": "cifar10-learner"
}
Expand All @@ -23,7 +23,7 @@
"components": [
{
"id": "cifar10-learner",
"path": "pt.learners.cifar10_scaffold_learner.CIFAR10ScaffoldLearner",
"path": "pt.learners.CIFAR10ScaffoldModelLearner",
"args": {
"train_idx_root": "{TRAIN_SPLIT_ROOT}",
"aggregation_epochs": "{AGGREGATION_EPOCHS}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,6 @@
}
}
},
{
"id": "shareable_generator",
"name": "FullModelShareableGenerator",
"args": {}
},
{
"id": "aggregator",
"name": "InTimeAccumulateWeightedAggregator",
"args": {
"expected_data_kind": {
"_model_weights_": "WEIGHT_DIFF",
"scaffold_c_diff": "WEIGHT_DIFF"
}
}
},
{
"id": "model_selector",
"name": "IntimeModelSelector",
Expand All @@ -64,18 +49,13 @@
],
"workflows": [
{
"id": "scatter_gather_ctl",
"name": "ScatterAndGatherScaffold",
"id": "scaffold_ctl",
"name": "Scaffold",
"args": {
"min_clients": "{min_clients}",
"num_rounds": "{num_rounds}",
"start_round": 0,
"wait_time_after_min_received": 10,
"aggregator_id": "aggregator",
"persistor_id": "persistor",
"shareable_generator_id": "shareable_generator",
"train_task_name": "train",
"train_timeout": 0

"persistor_id": "persistor"
}
},
{
Expand Down
2 changes: 1 addition & 1 deletion examples/advanced/cifar10/cifar10-sim/run_experiments.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env bash

export PYTHONPATH=${PWD}/..
export PYTHONPATH=${PYTHONPATH}:${PWD}/..

# download dataset
./prepare_data.sh
Expand Down
16 changes: 16 additions & 0 deletions examples/advanced/cifar10/pt/learners/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .cifar10_model_learner import CIFAR10ModelLearner
from .cifar10_scaffold_model_learner import CIFAR10ScaffoldModelLearner
2 changes: 2 additions & 0 deletions examples/advanced/cifar10/pt/learners/cifar10_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
from nvflare.app_common.abstract.learner_spec import Learner
from nvflare.app_common.app_constant import AppConstants, ModelName, ValidateType
from nvflare.app_opt.pt.fedproxloss import PTFedProxLoss
from nvflare.fuel.utils.deprecated import deprecated


@deprecated("Please use 'CIFAR10ModelLearner'")
class CIFAR10Learner(Learner): # also supports CIFAR10ScaffoldLearner
def __init__(
self,
Expand Down
Loading

0 comments on commit 62f2f1d

Please sign in to comment.