From 62f2f1d2837d06788eccd762646d79bba20962cf Mon Sep 17 00:00:00 2001 From: Holger Roth <6304754+holgerroth@users.noreply.github.com> Date: Fri, 8 Dec 2023 16:30:30 -0500 Subject: [PATCH] Simple FedAvg workflow (#2157) * fedavg model controller * use broadcast_and_wait; add FLComponentHelper class * Use FLModel in aggregate_fn and update_model * support FLModel in persistor * restructure model controller * simplify * rename results * formatting * remove debug msg * add docstrings * check results * rm unused file * add phases and stats collection * remove temp filenames * add argument check * add scaffold * create ModelController * restore persistance format manager and convert in ModelController * Use FLComponentWrapper * add experimental decorator * add relay and wait * update docstrings * fix decorators; remove relay and wait * reset job defaults * reset default alpha * formatting * address comments * address more comments * restore base controller * replace meta by default * address comments --- .../config/config_fed_client.json | 2 +- .../config/config_fed_server.json | 30 +- .../config/config_fed_client.json | 2 +- .../config/config_fed_server.json | 22 +- .../config/config_fed_client.json | 2 +- .../config/config_fed_client.json | 2 +- .../config/config_fed_server.json | 26 +- .../config/config_fed_client.json | 4 +- .../config/config_fed_server.json | 28 +- .../cifar10/cifar10-sim/run_experiments.sh | 2 +- .../advanced/cifar10/pt/learners/__init__.py | 16 + .../cifar10/pt/learners/cifar10_learner.py | 2 + .../cifar10_scaffold_model_learner.py | 145 ++++++++ nvflare/app_common/abstract/model_learner.py | 181 +--------- nvflare/app_common/app_constant.py | 10 +- .../app_common/utils/fl_component_wrapper.py | 214 +++++++++++ nvflare/app_common/utils/fl_model_utils.py | 20 ++ nvflare/app_common/workflows/base_fedavg.py | 146 ++++++++ nvflare/app_common/workflows/fedavg.py | 61 ++++ .../app_common/workflows/model_controller.py | 331 ++++++++++++++++++ nvflare/app_common/workflows/scaffold.py | 119 +++++++ 21 files changed, 1085 insertions(+), 280 deletions(-) create mode 100644 examples/advanced/cifar10/pt/learners/__init__.py create mode 100644 examples/advanced/cifar10/pt/learners/cifar10_scaffold_model_learner.py create mode 100644 nvflare/app_common/utils/fl_component_wrapper.py create mode 100644 nvflare/app_common/workflows/base_fedavg.py create mode 100644 nvflare/app_common/workflows/fedavg.py create mode 100644 nvflare/app_common/workflows/model_controller.py create mode 100644 nvflare/app_common/workflows/scaffold.py diff --git a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_central/cifar10_central/config/config_fed_client.json b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_central/cifar10_central/config/config_fed_client.json index c9ae2c1ca5..a30c316abd 100644 --- a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_central/cifar10_central/config/config_fed_client.json +++ b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_central/cifar10_central/config/config_fed_client.json @@ -25,7 +25,7 @@ "components": [ { "id": "cifar10-learner", - "path": "pt.learners.cifar10_model_learner.CIFAR10ModelLearner", + "path": "pt.learners.CIFAR10ModelLearner", "args": { "aggregation_epochs": "{AGGREGATION_EPOCHS}", "lr": 1e-2, diff --git a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_central/cifar10_central/config/config_fed_server.json b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_central/cifar10_central/config/config_fed_server.json index ed7900930e..968cda1c41 
100644 --- a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_central/cifar10_central/config/config_fed_server.json +++ b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_central/cifar10_central/config/config_fed_server.json @@ -20,16 +20,6 @@ } } }, - { - "id": "shareable_generator", - "name": "FullModelShareableGenerator", - "args": {} - }, - { - "id": "aggregator", - "name": "InTimeAccumulateWeightedAggregator", - "args": {} - }, { "id": "model_selector", "name": "IntimeModelSelector", @@ -50,19 +40,13 @@ ], "workflows": [ { - "id": "scatter_gather_ctl", - "name": "ScatterAndGather", - "args": { - "min_clients" : "{min_clients}", - "num_rounds" : "{num_rounds}", - "start_round": 0, - "wait_time_after_min_received": 10, - "aggregator_id": "aggregator", - "persistor_id": "persistor", - "shareable_generator_id": "shareable_generator", - "train_task_name": "train", - "train_timeout": 0 - } + "id": "fedavg_ctl", + "name": "FedAvg", + "args": { + "min_clients": "{min_clients}", + "num_rounds": "{num_rounds}", + "persistor_id": "persistor" + } }, { "id": "cross_site_model_eval", diff --git a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedavg/cifar10_fedavg/config/config_fed_client.json b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedavg/cifar10_fedavg/config/config_fed_client.json index 2cc6b86164..7d657dfa30 100644 --- a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedavg/cifar10_fedavg/config/config_fed_client.json +++ b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedavg/cifar10_fedavg/config/config_fed_client.json @@ -23,7 +23,7 @@ "components": [ { "id": "cifar10-learner", - "path": "pt.learners.cifar10_model_learner.CIFAR10ModelLearner", + "path": "pt.learners.CIFAR10ModelLearner", "args": { "train_idx_root": "{TRAIN_SPLIT_ROOT}", "aggregation_epochs": "{AGGREGATION_EPOCHS}", diff --git a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedavg/cifar10_fedavg/config/config_fed_server.json b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedavg/cifar10_fedavg/config/config_fed_server.json index d4a7e5196e..353fa3328b 100644 --- a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedavg/cifar10_fedavg/config/config_fed_server.json +++ b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedavg/cifar10_fedavg/config/config_fed_server.json @@ -29,16 +29,6 @@ } } }, - { - "id": "shareable_generator", - "name": "FullModelShareableGenerator", - "args": {} - }, - { - "id": "aggregator", - "name": "InTimeAccumulateWeightedAggregator", - "args": {} - }, { "id": "model_selector", "name": "IntimeModelSelector", @@ -59,18 +49,12 @@ ], "workflows": [ { - "id": "scatter_gather_ctl", - "name": "ScatterAndGather", + "id": "fedavg_ctl", + "name": "FedAvg", "args": { "min_clients": "{min_clients}", "num_rounds": "{num_rounds}", - "start_round": 0, - "wait_time_after_min_received": 10, - "aggregator_id": "aggregator", - "persistor_id": "persistor", - "shareable_generator_id": "shareable_generator", - "train_task_name": "train", - "train_timeout": 0 + "persistor_id": "persistor" } }, { diff --git a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedopt/cifar10_fedopt/config/config_fed_client.json b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedopt/cifar10_fedopt/config/config_fed_client.json index e98721acd2..950526610c 100644 --- a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedopt/cifar10_fedopt/config/config_fed_client.json +++ b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedopt/cifar10_fedopt/config/config_fed_client.json @@ -27,7 +27,7 @@ 
"components": [ { "id": "cifar10-learner", - "path": "pt.learners.cifar10_model_learner.CIFAR10ModelLearner", + "path": "pt.learners.CIFAR10ModelLearner", "args": { "train_idx_root": "{TRAIN_SPLIT_ROOT}", "aggregation_epochs": "{AGGREGATION_EPOCHS}", diff --git a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedprox/cifar10_fedprox/config/config_fed_client.json b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedprox/cifar10_fedprox/config/config_fed_client.json index 70c6132143..0df56b5109 100644 --- a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedprox/cifar10_fedprox/config/config_fed_client.json +++ b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedprox/cifar10_fedprox/config/config_fed_client.json @@ -27,7 +27,7 @@ "components": [ { "id": "cifar10-learner", - "path": "pt.learners.cifar10_model_learner.CIFAR10ModelLearner", + "path": "pt.learners.CIFAR10ModelLearner", "args": { "train_idx_root": "{TRAIN_SPLIT_ROOT}", "aggregation_epochs": "{AGGREGATION_EPOCHS}", diff --git a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedprox/cifar10_fedprox/config/config_fed_server.json b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedprox/cifar10_fedprox/config/config_fed_server.json index 2303715c91..ebab240e8c 100644 --- a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedprox/cifar10_fedprox/config/config_fed_server.json +++ b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_fedprox/cifar10_fedprox/config/config_fed_server.json @@ -32,16 +32,6 @@ } } }, - { - "id": "shareable_generator", - "name": "FullModelShareableGenerator", - "args": {} - }, - { - "id": "aggregator", - "name": "InTimeAccumulateWeightedAggregator", - "args": {} - }, { "id": "model_selector", "name": "IntimeModelSelector", @@ -62,18 +52,12 @@ ], "workflows": [ { - "id": "scatter_gather_ctl", - "name": "ScatterAndGather", + "id": "fedavg_ctl", + "name": "FedAvg", "args": { - "min_clients" : "{min_clients}", - "num_rounds" : "{num_rounds}", - "start_round": 0, - "wait_time_after_min_received": 10, - "aggregator_id": "aggregator", - "persistor_id": "persistor", - "shareable_generator_id": "shareable_generator", - "train_task_name": "train", - "train_timeout": 0 + "min_clients": "{min_clients}", + "num_rounds": "{num_rounds}", + "persistor_id": "persistor" } }, { diff --git a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_scaffold/cifar10_scaffold/config/config_fed_client.json b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_scaffold/cifar10_scaffold/config/config_fed_client.json index 56f5cde54f..ee6768bc3e 100644 --- a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_scaffold/cifar10_scaffold/config/config_fed_client.json +++ b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_scaffold/cifar10_scaffold/config/config_fed_client.json @@ -11,7 +11,7 @@ ], "executor": { "id": "Executor", - "path": "nvflare.app_common.executors.learner_executor.LearnerExecutor", + "path": "nvflare.app_common.executors.model_learner_executor.ModelLearnerExecutor", "args": { "learner_id": "cifar10-learner" } @@ -23,7 +23,7 @@ "components": [ { "id": "cifar10-learner", - "path": "pt.learners.cifar10_scaffold_learner.CIFAR10ScaffoldLearner", + "path": "pt.learners.CIFAR10ScaffoldModelLearner", "args": { "train_idx_root": "{TRAIN_SPLIT_ROOT}", "aggregation_epochs": "{AGGREGATION_EPOCHS}", diff --git a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_scaffold/cifar10_scaffold/config/config_fed_server.json 
b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_scaffold/cifar10_scaffold/config/config_fed_server.json index 01f38b4b18..aa14bed41e 100644 --- a/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_scaffold/cifar10_scaffold/config/config_fed_server.json +++ b/examples/advanced/cifar10/cifar10-sim/jobs/cifar10_scaffold/cifar10_scaffold/config/config_fed_server.json @@ -29,21 +29,6 @@ } } }, - { - "id": "shareable_generator", - "name": "FullModelShareableGenerator", - "args": {} - }, - { - "id": "aggregator", - "name": "InTimeAccumulateWeightedAggregator", - "args": { - "expected_data_kind": { - "_model_weights_": "WEIGHT_DIFF", - "scaffold_c_diff": "WEIGHT_DIFF" - } - } - }, { "id": "model_selector", "name": "IntimeModelSelector", @@ -64,18 +49,13 @@ ], "workflows": [ { - "id": "scatter_gather_ctl", - "name": "ScatterAndGatherScaffold", + "id": "scaffold_ctl", + "name": "Scaffold", "args": { "min_clients": "{min_clients}", "num_rounds": "{num_rounds}", - "start_round": 0, - "wait_time_after_min_received": 10, - "aggregator_id": "aggregator", - "persistor_id": "persistor", - "shareable_generator_id": "shareable_generator", - "train_task_name": "train", - "train_timeout": 0 + + "persistor_id": "persistor" } }, { diff --git a/examples/advanced/cifar10/cifar10-sim/run_experiments.sh b/examples/advanced/cifar10/cifar10-sim/run_experiments.sh index dcf4b731cb..9213c862b7 100755 --- a/examples/advanced/cifar10/cifar10-sim/run_experiments.sh +++ b/examples/advanced/cifar10/cifar10-sim/run_experiments.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -export PYTHONPATH=${PWD}/.. +export PYTHONPATH=${PYTHONPATH}:${PWD}/.. # download dataset ./prepare_data.sh diff --git a/examples/advanced/cifar10/pt/learners/__init__.py b/examples/advanced/cifar10/pt/learners/__init__.py new file mode 100644 index 0000000000..5d7bd1c363 --- /dev/null +++ b/examples/advanced/cifar10/pt/learners/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
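For orientation, the new `FedAvg` server workflow entry shown in the configs above maps one-to-one onto the controller introduced later in this patch; no aggregator or shareable_generator components are needed anymore. A minimal Python sketch of the equivalent instantiation (argument values are illustrative; the configs use the `{min_clients}`/`{num_rounds}` placeholders):

from nvflare.app_common.workflows.fedavg import FedAvg

# mirrors the "fedavg_ctl" workflow entry in config_fed_server.json
controller = FedAvg(
    min_clients=8,  # illustrative value
    num_rounds=50,  # illustrative value
    persistor_id="persistor",
)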
+ +from .cifar10_model_learner import CIFAR10ModelLearner +from .cifar10_scaffold_model_learner import CIFAR10ScaffoldModelLearner diff --git a/examples/advanced/cifar10/pt/learners/cifar10_learner.py b/examples/advanced/cifar10/pt/learners/cifar10_learner.py index 8b8261723a..651f32eac6 100644 --- a/examples/advanced/cifar10/pt/learners/cifar10_learner.py +++ b/examples/advanced/cifar10/pt/learners/cifar10_learner.py @@ -32,8 +32,10 @@ from nvflare.app_common.abstract.learner_spec import Learner from nvflare.app_common.app_constant import AppConstants, ModelName, ValidateType from nvflare.app_opt.pt.fedproxloss import PTFedProxLoss +from nvflare.fuel.utils.deprecated import deprecated +@deprecated("Please use 'CIFAR10ModelLearner'") class CIFAR10Learner(Learner): # also supports CIFAR10ScaffoldLearner def __init__( self, diff --git a/examples/advanced/cifar10/pt/learners/cifar10_scaffold_model_learner.py b/examples/advanced/cifar10/pt/learners/cifar10_scaffold_model_learner.py new file mode 100644 index 0000000000..bba92ca6da --- /dev/null +++ b/examples/advanced/cifar10/pt/learners/cifar10_scaffold_model_learner.py @@ -0,0 +1,145 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +import torch + +from nvflare.app_common.abstract.fl_model import FLModel +from nvflare.app_common.app_constant import AlgorithmConstants +from nvflare.app_opt.pt.scaffold import PTScaffoldHelper, get_lr_values + +from .cifar10_model_learner import CIFAR10ModelLearner + + +class CIFAR10ScaffoldModelLearner(CIFAR10ModelLearner): + def __init__( + self, + train_idx_root: str = "./dataset", + aggregation_epochs: int = 1, + lr: float = 1e-2, + fedproxloss_mu: float = 0.0, + central: bool = False, + analytic_sender_id: str = "analytic_sender", + batch_size: int = 64, + num_workers: int = 0, + ): + """Simple Scaffold CIFAR-10 Trainer. + Implements the training algorithm proposed in + Karimireddy et al. "SCAFFOLD: Stochastic Controlled Averaging for Federated Learning" + (https://arxiv.org/abs/1910.06378) using functions implemented in `PTScaffoldHelper` class. + + Args: + train_idx_root: directory with site training indices for CIFAR-10 data. + aggregation_epochs: the number of training epochs for a round. Defaults to 1. + lr: local learning rate. Float number. Defaults to 1e-2. + fedproxloss_mu: weight for FedProx loss. Float number. Defaults to 0.0 (no FedProx). + central: Bool. Whether to simulate central training. Default False. + analytic_sender_id: id of `AnalyticsSender` if configured as a client component. + If configured, TensorBoard events will be fired. Defaults to "analytic_sender". + batch_size: batch size for training and validation. + num_workers: number of workers for data loaders. + + Returns: + a Shareable with the updated local model after running `execute()` + or the best local model depending on the specified task. 
+ """ + + super().__init__( + train_idx_root=train_idx_root, + aggregation_epochs=aggregation_epochs, + lr=lr, + fedproxloss_mu=fedproxloss_mu, + central=central, + analytic_sender_id=analytic_sender_id, + batch_size=batch_size, + num_workers=num_workers, + ) + self.scaffold_helper = PTScaffoldHelper() + + def initialize(self): + # Initialize super class and SCAFFOLD + super().initialize() + self.scaffold_helper.init(model=self.model) + + def local_train(self, train_loader, model_global, val_freq: int = 0): + # local_train with SCAFFOLD steps + c_global_para, c_local_para = self.scaffold_helper.get_params() + for epoch in range(self.aggregation_epochs): + self.model.train() + epoch_len = len(train_loader) + self.epoch_global = self.epoch_of_start_time + epoch + self.info(f"Local epoch {self.site_name}: {epoch + 1}/{self.aggregation_epochs} (lr={self.lr})") + + for i, (inputs, labels) in enumerate(train_loader): + inputs, labels = inputs.to(self.device), labels.to(self.device) + # zero the parameter gradients + self.optimizer.zero_grad() + # forward + backward + optimize + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + + # FedProx loss term + if self.fedproxloss_mu > 0: + fed_prox_loss = self.criterion_prox(self.model, model_global) + loss += fed_prox_loss + + loss.backward() + self.optimizer.step() + + # SCAFFOLD step + curr_lr = get_lr_values(self.optimizer)[0] + self.scaffold_helper.model_update( + model=self.model, curr_lr=curr_lr, c_global_para=c_global_para, c_local_para=c_local_para + ) + + current_step = epoch_len * self.epoch_global + i + self.writer.add_scalar("train_loss", loss.item(), current_step) + + if val_freq > 0 and epoch % val_freq == 0: + acc = self.local_valid(self.valid_loader, tb_id="val_acc_local_model") + if acc > self.best_acc: + self.save_model(is_best=True) + + # Update the SCAFFOLD terms + self.scaffold_helper.terms_update( + model=self.model, + curr_lr=curr_lr, + c_global_para=c_global_para, + c_local_para=c_local_para, + model_global=model_global, + ) + + def train(self, model: FLModel) -> Union[str, FLModel]: + # return FLModel with extra control differences for SCAFFOLD + if AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL not in model.meta: + raise ValueError( + f"Expected model meta to contain AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL " + f"but meta was {model.meta}.", + ) + global_ctrl_weights = model.meta.get(AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL) + if not global_ctrl_weights: + raise ValueError("global_ctrl_weights were empty!") + # convert to tensor and load into c_global model + for k in global_ctrl_weights.keys(): + global_ctrl_weights[k] = torch.as_tensor(global_ctrl_weights[k]) + self.scaffold_helper.load_global_controls(weights=global_ctrl_weights) + + # local training with global model weights + result_model = super().train(model) + + # Add scaffold controls to resulting model + result_model.meta[AlgorithmConstants.SCAFFOLD_CTRL_DIFF] = self.scaffold_helper.get_delta_controls() + + return result_model diff --git a/nvflare/app_common/abstract/model_learner.py b/nvflare/app_common/abstract/model_learner.py index bae80029de..762b8c8ba1 100644 --- a/nvflare/app_common/abstract/model_learner.py +++ b/nvflare/app_common/abstract/model_learner.py @@ -12,172 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Union +from typing import Union -from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_context import FLContext -from nvflare.apis.fl_exception import TaskExecutionError from nvflare.app_common.abstract.fl_model import FLModel +from nvflare.app_common.utils.fl_component_wrapper import FLComponentWrapper -class ModelLearner(FLComponent): - +class ModelLearner(FLComponentWrapper): STATE = None def __init__(self): super().__init__() - self.engine = None - self.fl_ctx = None - self.workspace = None - self.shareable = None - self.args = None - self.site_name = None - self.job_id = None - self.app_root = None - self.job_root = None - self.workspace_root = None - self.abort_signal = None - self.current_round = 0 - self.total_rounds = 0 - - def is_aborted(self) -> bool: - """Check whether the task has been asked to abort by the framework. - - Returns: whether the task has been asked to abort by the framework - - """ - return self.abort_signal and self.abort_signal.triggered - - def get_shareable_header(self, key: str, default=None): - """Convenience method for getting specified header from the shareable. - - Args: - key: name of the header - default: default value if the header doesn't exist - - Returns: value of the header if it exists in the shareable; or the specified default if it doesn't. - - """ - if not self.shareable: - return default - return self.shareable.get_header(key, default) - - def get_context_prop(self, key: str, default=None): - """Convenience method for getting specified property from the FL Context. - - Args: - key: name of the property - default: default value if the prop doesn't exist in FL Context - - Returns: value of the prop if it exists in the context; or the specified default if it doesn't. - - """ - if not self.fl_ctx: - return default - assert isinstance(self.fl_ctx, FLContext) - return self.fl_ctx.get_prop(key, default) - - def get_component(self, component_id: str) -> Any: - """Get the specified component from the context - - Args: - component_id: ID of the component - - Returns: the specified component if it is defined; or None if not. - - """ - if self.engine: - return self.engine.get_component(component_id) - else: - return None - - def debug(self, msg: str): - """Convenience method for logging a DEBUG message with contextual info - - Args: - msg: the message to be logged - - Returns: - - """ - self.log_debug(self.fl_ctx, msg) - - def info(self, msg: str): - """Convenience method for logging an INFO message with contextual info - - Args: - msg: the message to be logged - - Returns: - - """ - self.log_info(self.fl_ctx, msg) - - def error(self, msg: str): - """Convenience method for logging an ERROR message with contextual info - - Args: - msg: the message to be logged - - Returns: - - """ - self.log_error(self.fl_ctx, msg) - - def warning(self, msg: str): - """Convenience method for logging a WARNING message with contextual info - - Args: - msg: the message to be logged - - Returns: - - """ - self.log_warning(self.fl_ctx, msg) - - def exception(self, msg: str): - """Convenience method for logging an EXCEPTION message with contextual info - - Args: - msg: the message to be logged - - Returns: - - """ - self.log_exception(self.fl_ctx, msg) - - def critical(self, msg: str): - """Convenience method for logging a CRITICAL message with contextual info - - Args: - msg: the message to be logged - - Returns: - - """ - self.log_critical(self.fl_ctx, msg) - - def stop_task(self, reason: str): - """Stop the current task. 
- This method is to be called by the Learner's training or validation code when it runs into - a situation that the task processing cannot continue. - - Args: - reason: why the task cannot continue - - Returns: - - """ - self.log_error(self.fl_ctx, f"Task stopped: {reason}") - raise TaskExecutionError(reason) - - def initialize(self): - """Called by the framework to initialize the Learner object. - This is called before the Learner can train or validate. - This is called only once. - - """ - pass def train(self, model: FLModel) -> Union[str, FLModel]: """Called by the framework to perform training. Can be called many times during the lifetime of the Learner. @@ -222,23 +67,3 @@ def configure(self, model: FLModel): """ pass - - def abort(self): - """Called by the framework for the Learner to gracefully abort the current task. - - This could be caused by multiple reasons: - - user issued the abort command to stop the whole job - - Controller runs into some condition that requires the job to be aborted - """ - pass - - def finalize(self): - """Called by the framework to finalize the Learner (close/release resources gracefully) when - the job is finished. - - After this call, the Learner will be destroyed. - - Args: - - """ - pass diff --git a/nvflare/app_common/app_constant.py b/nvflare/app_common/app_constant.py index a261b1e049..73928fd95b 100644 --- a/nvflare/app_common/app_constant.py +++ b/nvflare/app_common/app_constant.py @@ -14,7 +14,6 @@ class ExecutorTasks: - TRAIN = "train" VALIDATE = "validate" CROSS_VALIDATION = "__cross_validation" @@ -23,7 +22,6 @@ class ExecutorTasks: class AppConstants(object): - CONFIG_PATH = "config_path" MODEL_NETWORK = "model_network" MULTI_GPU = "multi_gpu" @@ -119,27 +117,25 @@ class AppConstants(object): SUBMIT_MODEL_NAME = "submit_model_name" VALIDATE_TYPE = "_validate_type" + CLIENT_UNKNOWN = "unknown" -class EnvironmentKey(object): +class EnvironmentKey(object): CHECKPOINT_DIR = "APP_CKPT_DIR" CHECKPOINT_FILE_NAME = "APP_CKPT" class DefaultCheckpointFileName(object): - GLOBAL_MODEL = "FL_global_model.pt" BEST_GLOBAL_MODEL = "best_FL_global_model.pt" class ModelName(object): - BEST_MODEL = "best_model" FINAL_MODEL = "final_model" class ModelFormat(object): - PT_CHECKPOINT = "pt_checkpoint" TORCH_SCRIPT = "torch_script" PT_ONNX = "pt_onnx" @@ -148,13 +144,11 @@ class ModelFormat(object): class ValidateType(object): - BEFORE_TRAIN_VALIDATE = "before_train_validate" MODEL_VALIDATE = "model_validate" class AlgorithmConstants(object): - SCAFFOLD_CTRL_DIFF = "scaffold_c_diff" SCAFFOLD_CTRL_GLOBAL = "scaffold_c_global" SCAFFOLD_CTRL_AGGREGATOR_ID = "scaffold_ctrl_aggregator" diff --git a/nvflare/app_common/utils/fl_component_wrapper.py b/nvflare/app_common/utils/fl_component_wrapper.py new file mode 100644 index 0000000000..bff4fc9729 --- /dev/null +++ b/nvflare/app_common/utils/fl_component_wrapper.py @@ -0,0 +1,214 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
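The `FLComponentWrapper` below takes over the convenience helpers that `ModelLearner` previously carried itself: logging, `stop_task`, `is_aborted`, and round bookkeeping. A minimal sketch of a learner relying on them (`MyLearner` is a hypothetical example class, not part of this patch):

from typing import Union

from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.abstract.model_learner import ModelLearner


class MyLearner(ModelLearner):
    def train(self, model: FLModel) -> Union[str, FLModel]:
        self.info(f"training round {self.current_round}/{self.total_rounds}")
        if self.is_aborted():
            self.stop_task("abort signal received")  # raises TaskExecutionError
        # ... local training that updates model.params would go here ...
        return model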
+ +from typing import Any + +from nvflare.apis.fl_component import FLComponent +from nvflare.apis.fl_context import FLContext +from nvflare.apis.fl_exception import TaskExecutionError + + +class FLComponentWrapper(FLComponent): + STATE = None + + def __init__(self): + super().__init__() + self.engine = None + self.fl_ctx = None + self.workspace = None + self.shareable = None + self.args = None + self.site_name = None + self.job_id = None + self.app_root = None + self.job_root = None + self.workspace_root = None + self.abort_signal = None + self.current_round = 0 + self.total_rounds = 0 + + def is_aborted(self) -> bool: + """Check whether the task has been asked to abort by the framework. + + Returns: whether the task has been asked to abort by the framework + + """ + return self.abort_signal and self.abort_signal.triggered + + def get_shareable_header(self, key: str, default=None): + """Convenience method for getting specified header from the shareable. + + Args: + key: name of the header + default: default value if the header doesn't exist + + Returns: value of the header if it exists in the shareable; or the specified default if it doesn't. + + """ + if not self.shareable: + return default + return self.shareable.get_header(key, default) + + def get_context_prop(self, key: str, default=None): + """Convenience method for getting specified property from the FL Context. + + Args: + key: name of the property + default: default value if the prop doesn't exist in FL Context + + Returns: value of the prop if it exists in the context; or the specified default if it doesn't. + + """ + if not self.fl_ctx: + return default + assert isinstance(self.fl_ctx, FLContext) + return self.fl_ctx.get_prop(key, default) + + def get_component(self, component_id: str) -> Any: + """Get the specified component from the context + + Args: + component_id: ID of the component + + Returns: the specified component if it is defined; or None if not. + + """ + if self.engine: + return self.engine.get_component(component_id) + else: + return None + + def debug(self, msg: str): + """Convenience method for logging a DEBUG message with contextual info + + Args: + msg: the message to be logged + + Returns: + + """ + self.log_debug(self.fl_ctx, msg) + + def info(self, msg: str): + """Convenience method for logging an INFO message with contextual info + + Args: + msg: the message to be logged + + Returns: + + """ + self.log_info(self.fl_ctx, msg) + + def error(self, msg: str): + """Convenience method for logging an ERROR message with contextual info + + Args: + msg: the message to be logged + + Returns: + + """ + self.log_error(self.fl_ctx, msg) + + def warning(self, msg: str): + """Convenience method for logging a WARNING message with contextual info + + Args: + msg: the message to be logged + + Returns: + + """ + self.log_warning(self.fl_ctx, msg) + + def exception(self, msg: str): + """Convenience method for logging an EXCEPTION message with contextual info + + Args: + msg: the message to be logged + + Returns: + + """ + self.log_exception(self.fl_ctx, msg) + + def critical(self, msg: str): + """Convenience method for logging a CRITICAL message with contextual info + + Args: + msg: the message to be logged + + Returns: + + """ + self.log_critical(self.fl_ctx, msg) + + def stop_task(self, reason: str): + """Stop the current task. + This method is to be called by the Learner's training or validation code when it runs into + a situation that the task processing cannot continue. 
+ + Args: + reason: why the task cannot continue + + Returns: + + """ + self.log_error(self.fl_ctx, f"Task stopped: {reason}") + raise TaskExecutionError(reason) + + def initialize(self): + """Called by the framework to initialize the Learner object. + This is called before the Learner can train or validate. + This is called only once. + + """ + pass + + def abort(self): + """Called by the framework for the Learner to gracefully abort the current task. + + This could be caused by multiple reasons: + - user issued the abort command to stop the whole job + - Controller runs into some condition that requires the job to be aborted + """ + pass + + def finalize(self): + """Called by the framework to finalize the Learner (close/release resources gracefully) when + the job is finished. + + After this call, the Learner will be destroyed. + + Args: + + """ + pass + + def event(self, event_type): + """Fires an event. + + Args: + event_type (str): The type of event. + """ + self.fire_event(event_type, self.fl_ctx) + + def panic(self, reason: str): + """Signals a fatal condition that could cause the RUN to end. + + Args: + reason (str): The reason for panic. + """ + self.system_panic(reason, self.fl_ctx) diff --git a/nvflare/app_common/utils/fl_model_utils.py b/nvflare/app_common/utils/fl_model_utils.py index ed685ddd78..4ce9d01d6c 100644 --- a/nvflare/app_common/utils/fl_model_utils.py +++ b/nvflare/app_common/utils/fl_model_utils.py @@ -193,3 +193,23 @@ def set_meta_prop(model: FLModel, key: str, value: Any): @staticmethod + def get_configs(model: FLModel) -> Optional[dict]: + return FLModelUtils.get_meta_prop(model, MetaKey.CONFIGS) + + @staticmethod + def update_model(model: FLModel, model_update: FLModel, replace_meta: bool = True) -> FLModel: + if model.params_type != ParamsType.FULL: + raise RuntimeError( + f"params_type {model.params_type} of `model` not supported! Expected `ParamsType.FULL`." + ) + + if replace_meta: + model.meta = model_update.meta + else: + model.meta.update(model_update.meta) + if model_update.params_type == ParamsType.FULL: + model.params = model_update.params + elif model_update.params_type == ParamsType.DIFF: + for v_name, v_value in model_update.params.items(): + model.params[v_name] = model.params[v_name] + v_value + else: + raise RuntimeError(f"params_type {model_update.params_type} of `model_update` not supported!") + return model diff --git a/nvflare/app_common/workflows/base_fedavg.py b/nvflare/app_common/workflows/base_fedavg.py new file mode 100644 index 0000000000..d031998e35 --- /dev/null +++ b/nvflare/app_common/workflows/base_fedavg.py @@ -0,0 +1,146 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
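Before the `BaseFedAvg` code below, a quick illustration of the `FLModelUtils.update_model` helper added above: it requires a FULL base model and either replaces the params (FULL update) or adds them in place (DIFF update). A sketch with illustrative values:

import numpy as np

from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
from nvflare.app_common.utils.fl_model_utils import FLModelUtils

base = FLModel(params_type=ParamsType.FULL, params={"w": np.array([1.0, 2.0])})
delta = FLModel(params_type=ParamsType.DIFF, params={"w": np.array([0.5, -0.5])}, meta={"current_round": 3})

updated = FLModelUtils.update_model(base, delta)  # replace_meta=True by default
# updated.params["w"] is now [1.5, 1.5]; updated.meta was replaced by delta.meta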
+ +import random +from typing import List + +from nvflare.apis.fl_constant import FLMetaKey +from nvflare.app_common.abstract.fl_model import FLModel +from nvflare.app_common.aggregators.weighted_aggregation_helper import WeightedAggregationHelper +from nvflare.app_common.app_constant import AppConstants +from nvflare.app_common.app_event_type import AppEventType +from nvflare.app_common.utils.fl_model_utils import FLModelUtils +from nvflare.security.logging import secure_format_exception + +from .model_controller import ModelController + + +class BaseFedAvg(ModelController): + """The base controller for FedAvg Workflow. *Note*: This class is based on the experimental `ModelController`. + + Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629). + The model persistor (persistor_id) is used to load the initial global model which is sent to a list of clients. + Each client sends its updated weights after local training, which are aggregated. + Next, the global model is updated. + The model persistor also saves the model after training. + + Provides the default implementations for the following routines: + - def sample_clients(self, min_clients) + - def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel + - def update_model(self, aggr_result) + + The `run` routine needs to be implemented by the derived class: + + - def run(self) + """ + + def sample_clients(self, min_clients): + """Called by the `run` routine to get a list of available clients. + + Args: + min_clients: number of clients to return. + + Returns: list of clients. + + """ + self._min_clients = min_clients + + clients = self.engine.get_clients() + if len(clients) < self._min_clients: + self._min_clients = len(clients) + + if self._min_clients < len(clients): + random.shuffle(clients) + clients = clients[0 : self._min_clients] + + return clients + + @staticmethod + def _check_results(results: List[FLModel]): + empty_clients = [] + for _result in results: + if not _result.params: + empty_clients.append(_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN)) + + if len(empty_clients) > 0: + raise ValueError(f"Result from client(s) {empty_clients} is empty!") + + @staticmethod + def _aggregate_fn(results: List[FLModel]) -> FLModel: + aggregation_helper = WeightedAggregationHelper() + for _result in results: + aggregation_helper.add( + data=_result.params, + weight=_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0), + contributor_name=_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN), + contribution_round=_result.meta.get("current_round", None), + ) + + aggregated_dict = aggregation_helper.get_result() + + aggr_result = FLModel( + params=aggregated_dict, + params_type=results[0].params_type, + meta={"nr_aggregated": len(results), "current_round": results[0].meta["current_round"]}, + ) + return aggr_result + + def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel: + """Called by the `run` routine to aggregate the training results of clients. + + Args: + results: a list of FLModel containing training results of the clients. + aggregate_fn: a function that turns the list of FLModel into one resulting (aggregated) FLModel. + + Returns: aggregated FLModel.
+ + """ + self.debug("Start aggregation.") + self.event(AppEventType.BEFORE_AGGREGATION) + self._check_results(results) + + if not aggregate_fn: + aggregate_fn = self._aggregate_fn + + self.info(f"aggregating {len(results)} update(s) at round {self._current_round}") + try: + aggr_result = aggregate_fn(results) + except Exception as e: + error_msg = f"Exception in aggregate call: {secure_format_exception(e)}" + self.exception(error_msg) + self.panic(error_msg) + return FLModel() + self._results = [] + + self.fl_ctx.set_prop(AppConstants.AGGREGATION_RESULT, aggr_result, private=True, sticky=False) + self.event(AppEventType.AFTER_AGGREGATION) + self.debug("End aggregation.") + + return aggr_result + + def update_model(self, aggr_result): + """Called by the `run` routine to update the current global model (self.model) given the aggregated result. + + Args: + aggr_result: aggregated FLModel. + + Returns: None. + + """ + self.event(AppEventType.BEFORE_SHAREABLE_TO_LEARNABLE) + + self.model = FLModelUtils.update_model(self.model, aggr_result) + + self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self.model, private=True, sticky=True) + self.event(AppEventType.AFTER_SHAREABLE_TO_LEARNABLE) diff --git a/nvflare/app_common/workflows/fedavg.py b/nvflare/app_common/workflows/fedavg.py new file mode 100644 index 0000000000..516c25e165 --- /dev/null +++ b/nvflare/app_common/workflows/fedavg.py @@ -0,0 +1,61 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base_fedavg import BaseFedAvg + + +class FedAvg(BaseFedAvg): + """Controller for FedAvg Workflow. *Note*: This class is based on the experimental `ModelController`. + Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629). + + Provides the implementations for the `run` routine, controlling the main workflow: + - def run(self) + + The parent classes provide the default implementations for other routines. + + Args: + min_clients (int, optional): The minimum number of clients responses before + Workflow starts to wait for `wait_time_after_min_received`. Note that the workflow will move forward + when all available clients have responded regardless of this value. Defaults to 1000. + num_rounds (int, optional): The total number of training rounds. Defaults to 5. + persistor_id (str, optional): ID of the persistor component. Defaults to "persistor". + ignore_result_error (bool, optional): whether this controller can proceed if client result has errors. + Defaults to False. + allow_empty_global_weights (bool, optional): whether to allow empty global weights. Some pipelines can have + empty global weights at first round, such that clients start training from scratch without any global info. + Defaults to False. + task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.5. + persist_every_n_rounds (int, optional): persist the global model every n rounds. Defaults to 1. + If n is 0 then no persist. 
+ """ + + def run(self) -> None: + self.info("Start FedAvg.") + + for self._current_round in range(self._num_rounds): + self.info(f"Round {self._current_round} started.") + + clients = self.sample_clients(self._min_clients) + + results = self.send_model_and_wait(targets=clients, data=self.model) + + aggregate_results = self.aggregate( + results, aggregate_fn=None + ) # if no `aggregate_fn` provided, default `WeightedAggregationHelper` is used + + self.update_model(aggregate_results) + + self.save_model() + + self.info("Finished FedAvg.") diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py new file mode 100644 index 0000000000..5320a432be --- /dev/null +++ b/nvflare/app_common/workflows/model_controller.py @@ -0,0 +1,331 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import List, Union + +from nvflare.apis.client import Client +from nvflare.apis.controller_spec import OperatorMethod, TaskOperatorKey +from nvflare.apis.fl_constant import ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal +from nvflare.app_common.abstract.fl_model import FLModel, ParamsType +from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor +from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable +from nvflare.app_common.app_constant import AppConstants +from nvflare.app_common.app_event_type import AppEventType +from nvflare.app_common.utils.fl_component_wrapper import FLComponentWrapper +from nvflare.app_common.utils.fl_model_utils import FLModelUtils +from nvflare.fuel.utils.experimental import experimental +from nvflare.fuel.utils.validation_utils import check_non_negative_int, check_positive_int, check_str +from nvflare.security.logging import secure_format_exception +from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector + + +@experimental +class ModelController(Controller, FLComponentWrapper): + def __init__( + self, + min_clients: int = 1000, + num_rounds: int = 5, + persistor_id="", + ignore_result_error: bool = False, + allow_empty_global_weights: bool = False, + task_check_period: float = 0.5, + persist_every_n_rounds: int = 1, + ): + """FLModel based controller. + + Args: + min_clients (int, optional): The minimum number of clients responses before + Workflow starts to wait for `wait_time_after_min_received`. Note that the workflow will move forward + when all available clients have responded regardless of this value. Defaults to 1000. + num_rounds (int, optional): The total number of training rounds. Defaults to 5. + persistor_id (str, optional): ID of the persistor component. Defaults to "persistor". 
+ ignore_result_error (bool, optional): whether this controller can proceed if client result has errors. + Defaults to False. + allow_empty_global_weights (bool, optional): whether to allow empty global weights. Some pipelines can have + empty global weights at first round, such that clients start training from scratch without any global info. + Defaults to False. + task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.5. + persist_every_n_rounds (int, optional): persist the global model every n rounds. Defaults to 1. + If n is 0, the model is not persisted. + """ + super().__init__(task_check_period=task_check_period) + + # Check arguments + check_positive_int("min_clients", min_clients) + check_non_negative_int("num_rounds", num_rounds) + check_non_negative_int("persist_every_n_rounds", persist_every_n_rounds) + check_str("persistor_id", persistor_id) + if not isinstance(task_check_period, (int, float)): + raise TypeError(f"task_check_period must be an int or float but got {type(task_check_period)}") + elif task_check_period <= 0: + raise ValueError("task_check_period must be greater than 0.") + self._task_check_period = task_check_period + self.persistor_id = persistor_id + self.persistor = None + + # config data + self._min_clients = min_clients + self._num_rounds = num_rounds + self._persist_every_n_rounds = persist_every_n_rounds + self.ignore_result_error = ignore_result_error + self.allow_empty_global_weights = allow_empty_global_weights + + # workflow phases: init, train, validate + self._phase = AppConstants.PHASE_INIT + self._current_round = None + + # model related + self.model = None + self._results = [] + + def start_controller(self, fl_ctx: FLContext) -> None: + self.fl_ctx = fl_ctx + self.info("Initializing ModelController workflow.") + + if self.persistor_id: + self.persistor = self._engine.get_component(self.persistor_id) + if not isinstance(self.persistor, LearnablePersistor): + self.panic( + f"Model Persistor {self.persistor_id} must be a LearnablePersistor type object, " + f"but got {type(self.persistor)}" + ) + return + + # initialize global model + if self.persistor: + global_weights = self.persistor.load(self.fl_ctx) + + if not isinstance(global_weights, ModelLearnable): + self.panic( + f"Expected global weights to be of type `ModelLearnable` but received {type(global_weights)}" + ) + return + + if global_weights.is_empty(): + if not self.allow_empty_global_weights: + # if empty not allowed, further check whether it is available from fl_ctx + global_weights = self.fl_ctx.get_prop(AppConstants.GLOBAL_MODEL) + + if not global_weights.is_empty(): + self.model = FLModel( + params_type=ParamsType.FULL, + params=global_weights[ModelLearnableKey.WEIGHTS], + meta=global_weights[ModelLearnableKey.META], + ) + elif self.allow_empty_global_weights: + self.model = FLModel(params_type=ParamsType.FULL, params={}) + else: + self.panic( + f"Neither `persistor` {self.persistor_id} nor `fl_ctx` returned a global model! If this was intended, set `self.allow_empty_global_weights` to `True`."
+ ) + return + else: + self.model = FLModel(params_type=ParamsType.FULL, params={}) + + self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self.model, private=True, sticky=True) + self.event(AppEventType.INITIAL_MODEL_LOADED) + + self.engine = self.fl_ctx.get_engine() + self.initialize() + + def _build_shareable(self, data: FLModel = None) -> Shareable: + if not data:  # if no data is given, send self.model + data = self.model + + data_shareable: Shareable = FLModelUtils.to_shareable(data) + data_shareable.set_header(AppConstants.CURRENT_ROUND, self._current_round) + data_shareable.set_header(AppConstants.NUM_ROUNDS, self._num_rounds) + data_shareable.add_cookie(AppConstants.CONTRIBUTION_ROUND, self._current_round) + + return data_shareable + + def send_model_and_wait( + self, + targets: Union[List[Client], List[str], None] = None, + data: FLModel = None, + task_name=AppConstants.TASK_TRAIN, + timeout: int = 0, + wait_time_after_min_received: int = 10, + ) -> List: + """Send the current global model or given data to a list of targets. + + The task is scheduled into a task list. Clients can request tasks and the controller will dispatch the task to eligible clients. + + Args: + targets: the list of eligible clients or client names or None (all clients). Defaults to None. + data: FLModel to be sent to clients. If no data is given, send `self.model`. + task_name (str, optional): Name of the train task. Defaults to "train". + timeout (int, optional): Time to wait for clients to do local training. Defaults to 0, i.e., never time out. + wait_time_after_min_received (int, optional): Time to wait before beginning aggregation after + minimum number of client responses has been received. Defaults to 10. + """ + + if not isinstance(task_name, str): + raise TypeError("task_name must be a string but got {}".format(type(task_name))) + check_non_negative_int("timeout", timeout) + check_non_negative_int("wait_time_after_min_received", wait_time_after_min_received) + + # Create train_task + data_shareable = self._build_shareable(data) + + operator = { + TaskOperatorKey.OP_ID: task_name, + TaskOperatorKey.METHOD: OperatorMethod.BROADCAST, + TaskOperatorKey.TIMEOUT: timeout, + } + + train_task = Task( + name=task_name, + data=data_shareable, + operator=operator, + props={}, + timeout=timeout, + before_task_sent_cb=self._prepare_task_data, + result_received_cb=self._process_result, + ) + + self._results = []  # reset results list + self.info(f"Sending task {task_name} to {[client.name for client in targets]}") + self.broadcast_and_wait( + task=train_task, + targets=targets, + min_responses=self._min_clients, + wait_time_after_min_received=wait_time_after_min_received, + fl_ctx=self.fl_ctx, + abort_signal=self.abort_signal, + ) + + if targets is not None: + if len(self._results) != self._min_clients: + self.warning( + f"Number of results ({len(self._results)}) is different from min_clients ({self._min_clients})."
+ ) + + # de-reference the internal results before returning + results = self._results + self._results = [] + return results + + def _prepare_task_data(self, client_task: ClientTask, fl_ctx: FLContext) -> None: + fl_ctx.set_prop(AppConstants.TRAIN_SHAREABLE, client_task.task.data, private=True, sticky=False) + self.event(AppEventType.BEFORE_TRAIN_TASK) + + def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: + self.fl_ctx = fl_ctx + result = client_task.result + client_name = client_task.client.name + + self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx) + + # Turn result into FLModel + result_model = FLModelUtils.from_shareable(result) + result_model.meta["client_name"] = client_name + result_model.meta["current_round"] = self._current_round + result_model.meta["total_rounds"] = self._num_rounds + + self._results.append(result_model) + + # Cleanup task result + client_task.result = None + + def process_result_of_unknown_task( + self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext + ) -> None: + if self._phase == AppConstants.PHASE_TRAIN and task_name == AppConstants.TASK_TRAIN: + self._accept_train_result(client_name=client.name, result=result, fl_ctx=fl_ctx) + self.info(f"Result of unknown task {task_name} accepted.") + else: + self.error("Ignoring result from unknown task.") + + def _accept_train_result(self, client_name: str, result: Shareable, fl_ctx: FLContext): + self.fl_ctx = fl_ctx + rc = result.get_return_code() + + # Raise panic if bad peer context or execution exception. + if rc and rc != ReturnCode.OK: + if self.ignore_result_error: + self.warning( + f"Ignore the train result from {client_name} at round {self._current_round}. Train result error code: {rc}", + ) + else: + self.panic( + f"Result from {client_name} is bad, error code: {rc}. " + f"{self.__class__.__name__} exiting at round {self._current_round}." + ) + return + + self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self._current_round, private=True, sticky=True) + self.fl_ctx.set_prop(AppConstants.TRAINING_RESULT, result, private=True, sticky=False) + + @abstractmethod + def run(self): + """Main `run` routine called by the Controller's `control_flow` to execute the workflow. + + Returns: None.
+ + """ + raise NotImplementedError + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None: + self._phase = AppConstants.PHASE_TRAIN + fl_ctx.set_prop(AppConstants.PHASE, self._phase, private=True, sticky=False) + fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self._num_rounds, private=True, sticky=False) + self.fl_ctx = fl_ctx + self.abort_signal = abort_signal + try: + self.info("Beginning model controller run.") + self.event(AppEventType.TRAINING_STARTED) + self._phase = AppConstants.PHASE_TRAIN + + self.run() + self._phase = AppConstants.PHASE_FINISHED + except Exception as e: + error_msg = f"Exception in model controller run: {secure_format_exception(e)}" + self.exception(error_msg) + self.panic(error_msg) + + def save_model(self): + if self.persistor: + if ( + self._persist_every_n_rounds != 0 and (self._current_round + 1) % self._persist_every_n_rounds == 0 + ) or self._current_round == self._num_rounds - 1: + self.info("Start persist model on server.") + self.event(AppEventType.BEFORE_LEARNABLE_PERSIST) + ml = make_model_learnable(weights=self.model.params, meta_props=self.model.meta) + self.persistor.save(ml, self.fl_ctx) + self.event(AppEventType.AFTER_LEARNABLE_PERSIST) + self.info("End persist model on server.") + + def stop_controller(self, fl_ctx: FLContext): + self._phase = AppConstants.PHASE_FINISHED + self.fl_ctx = fl_ctx + self.finalize() + + def handle_event(self, event_type: str, fl_ctx: FLContext): + super().handle_event(event_type, fl_ctx) + if event_type == InfoCollector.EVENT_TYPE_GET_STATS: + collector = fl_ctx.get_prop(InfoCollector.CTX_KEY_STATS_COLLECTOR, None) + if collector: + if not isinstance(collector, GroupInfoCollector): + raise TypeError("collector must be GroupInfoCollector but got {}".format(type(collector))) + + collector.add_info( + group_name=self._name, + info={"phase": self._phase, "current_round": self._current_round, "num_rounds": self._num_rounds}, + ) diff --git a/nvflare/app_common/workflows/scaffold.py b/nvflare/app_common/workflows/scaffold.py new file mode 100644 index 0000000000..3d3dd3b86f --- /dev/null +++ b/nvflare/app_common/workflows/scaffold.py @@ -0,0 +1,119 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import List + +import numpy as np + +from nvflare.apis.fl_constant import FLMetaKey +from nvflare.app_common.abstract.fl_model import FLModel +from nvflare.app_common.aggregators.weighted_aggregation_helper import WeightedAggregationHelper +from nvflare.app_common.app_constant import AlgorithmConstants, AppConstants + +from .base_fedavg import BaseFedAvg + + +class Scaffold(BaseFedAvg): + """Controller for Scaffold Workflow. *Note*: This class is based on the experimental `ModelController`. + Implements [SCAFFOLD](https://proceedings.mlr.press/v119/karimireddy20a.html). 
+ + Provides the implementations for the `run` routine, controlling the main workflow: + - def run(self) + + The parent classes provide the default implementations for other routines. + + Args: + min_clients (int, optional): The minimum number of client responses before + the Workflow starts to wait for `wait_time_after_min_received`. Note that the workflow will move forward + when all available clients have responded regardless of this value. Defaults to 1000. + num_rounds (int, optional): The total number of training rounds. Defaults to 5. + persistor_id (str, optional): ID of the persistor component. Defaults to "persistor". + ignore_result_error (bool, optional): whether this controller can proceed if client result has errors. + Defaults to False. + allow_empty_global_weights (bool, optional): whether to allow empty global weights. Some pipelines can have + empty global weights at first round, such that clients start training from scratch without any global info. + Defaults to False. + task_check_period (float, optional): interval for checking status of tasks. Defaults to 0.5. + persist_every_n_rounds (int, optional): persist the global model every n rounds. Defaults to 1. + If n is 0, the model is not persisted. + """ + + def initialize(self): + super().initialize() + self._global_ctrl_weights = copy.deepcopy(self.model.params) + # Initialize correction term with zeros + for k in self._global_ctrl_weights.keys(): + self._global_ctrl_weights[k] = np.zeros_like(self._global_ctrl_weights[k]) + + def run(self) -> None: + self.info("Start Scaffold.") + + for self._current_round in range(self._num_rounds): + self.info(f"Round {self._current_round} started.") + + clients = self.sample_clients(self._min_clients) + + # Add SCAFFOLD global control terms to global model meta + global_model = self.model + global_model.meta[AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL] = self._global_ctrl_weights + + results = self.send_model_and_wait(targets=clients, data=global_model) + + aggregate_results = self.aggregate(results, aggregate_fn=scaffold_aggregate_fn) + + self.update_model(aggregate_results) + + # update SCAFFOLD global controls + ctrl_diff = aggregate_results.meta[AlgorithmConstants.SCAFFOLD_CTRL_DIFF] + for v_name, v_value in ctrl_diff.items(): + self._global_ctrl_weights[v_name] += v_value + + self.save_model() + + self.info("Finished Scaffold.") + + +def scaffold_aggregate_fn(results: List[FLModel]) -> FLModel: + # aggregates both the model weights and the SCAFFOLD control terms + + aggregation_helper = WeightedAggregationHelper() + ctrl_aggregation_helper = WeightedAggregationHelper() + for _result in results: + aggregation_helper.add( + data=_result.params, + weight=_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0), + contributor_name=_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN), + contribution_round=_result.meta.get("current_round", None), + ) + ctrl_aggregation_helper.add( + data=_result.meta[AlgorithmConstants.SCAFFOLD_CTRL_DIFF], + weight=_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0), + contributor_name=_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN), + contribution_round=_result.meta.get("current_round", None), + ) + + aggregated_dict = aggregation_helper.get_result() + + aggr_result = FLModel( + params=aggregated_dict, + params_type=results[0].params_type, + meta={ + AlgorithmConstants.SCAFFOLD_CTRL_DIFF: ctrl_aggregation_helper.get_result(), + "nr_aggregated": len(results), + "current_round": results[0].meta["current_round"], + }, + ) + + return aggr_result
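The `aggregate_fn` hook demonstrated by `scaffold_aggregate_fn` is the intended extension point of `BaseFedAvg.aggregate`: any callable that maps a list of FLModel results to a single FLModel can be plugged in. A sketch of a custom unweighted-mean aggregator under the same contract (a hypothetical example, not part of this patch):

from typing import List

import numpy as np

from nvflare.app_common.abstract.fl_model import FLModel


def mean_aggregate_fn(results: List[FLModel]) -> FLModel:
    # plain average of client params, ignoring per-client step counts
    params = {k: np.mean([r.params[k] for r in results], axis=0) for k in results[0].params}
    return FLModel(
        params=params,
        params_type=results[0].params_type,
        meta={"nr_aggregated": len(results), "current_round": results[0].meta["current_round"]},
    )


# inside a BaseFedAvg subclass's run():
#     aggr_result = self.aggregate(results, aggregate_fn=mean_aggregate_fn)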