fix ccwf with client api, params_converter based on task (#2260)
SYangster authored Jan 5, 2024
1 parent f66a4f6 commit 21c3f1d
Showing 10 changed files with 31 additions and 22 deletions.
4 changes: 2 additions & 2 deletions examples/hello-world/step-by-step/cifar10/code/fl/train.py
@@ -93,7 +93,7 @@ def evaluate(input_weights):

# (4) receive FLModel from NVFlare
input_model = flare.receive()
-client_id = flare.system_info().get("site_name", None)
+client_id = flare.get_site_name()

# Based on different "task" we will do different things
# for "train" task (flare.is_train()) we use the received model to do training and/or evaluation
@@ -104,7 +104,7 @@ def evaluate(input_weights):
# for "submit_model" task (flare.is_submit_model()) we just need to send back the local model
# (5) performing train task on received model
if flare.is_train():
print(f"({client_id}) round={input_model.current_round}/{input_model.total_rounds-1}")
print(f"({client_id}) current_round={input_model.current_round}, total_rounds={input_model.total_rounds}")

# (5.1) loads model from NVFlare
net.load_state_dict(input_model.params)
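For context, the changed lines above sit inside the Client API training loop of the example script. A minimal sketch of that pattern, assuming a stand-in PyTorch network and a placeholder metric (the real example trains and evaluates CIFAR-10):

```python
import torch

import nvflare.client as flare
from nvflare.app_common.abstract.fl_model import FLModel

net = torch.nn.Linear(10, 2)  # stand-in for the example's CIFAR-10 network

flare.init()
while flare.is_running():
    input_model = flare.receive()       # (4) FLModel sent by the server
    client_id = flare.get_site_name()   # simpler replacement for system_info().get("site_name")

    if flare.is_train():                # (5) performing train task on received model
        print(f"({client_id}) current_round={input_model.current_round}, total_rounds={input_model.total_rounds}")
        net.load_state_dict(input_model.params)  # (5.1) load the global weights
        # ... run local training and evaluation here ...
        accuracy = 0.0                            # placeholder metric for this sketch
        flare.send(FLModel(params=net.state_dict(), metrics={"accuracy": accuracy}))
```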
5 changes: 0 additions & 5 deletions job_templates/sklearn_linear/config_fed_client.conf
@@ -44,11 +44,6 @@
# the custom code need to send back both the trained parameters and the evaluation metric
# otherwise only trained parameters are expected
train_with_evaluation = true
-
-# if launch_once is true, the executor will only call launcher.launch_task once
-# for the whole job, if launch_once is false, the executor will call launcher.launch_task
-# everytime it receives a task from server
-launch_once = true
}
}
}
5 changes: 0 additions & 5 deletions job_templates/sklearn_svm/config_fed_client.conf
@@ -44,11 +44,6 @@
# the custom code need to send back both the trained parameters and the evaluation metric
# otherwise only trained parameters are expected
train_with_evaluation = true
-
-# if launch_once is true, the executor will only call launcher.launch_task once
-# for the whole job, if launch_once is false, the executor will call launcher.launch_task
-# everytime it receives a task from server
-launch_once = true
}
}
}
4 changes: 4 additions & 0 deletions job_templates/swarm_cse_pt/config_fed_client.conf
@@ -1,4 +1,8 @@
format_version = 2
+# This is the application script which will be invoked. Client can replace this script with user's own training script.
+app_script = "train.py"
+# Additional arguments needed by the training code.
+app_config = ""
# Client Computing Executors.
executors = [
{
14 changes: 9 additions & 5 deletions nvflare/app_common/abstract/params_converter.py
@@ -13,7 +13,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, List

from nvflare.apis.dxo import from_shareable
from nvflare.apis.filter import Filter
@@ -22,10 +22,14 @@


class ParamsConverter(Filter, ABC):
-    def process(self, shareable: Shareable, fl_ctx: FLContext) -> Shareable:
-        dxo = from_shareable(shareable)
-        dxo.data = self.convert(dxo.data, fl_ctx)
-        dxo.update_shareable(shareable)
+    def __init__(self, supported_tasks: List[str] = None):
+        self.supported_tasks = supported_tasks
+
+    def process(self, task_name: str, shareable: Shareable, fl_ctx: FLContext) -> Shareable:
+        if not self.supported_tasks or task_name in self.supported_tasks:
+            dxo = from_shareable(shareable)
+            dxo.data = self.convert(dxo.data, fl_ctx)
+            dxo.update_shareable(shareable)
        return shareable

@abstractmethod
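The new supported_tasks argument turns ParamsConverter.process into a per-task gate: conversion only runs when the incoming task name is in the list, or when no list was given. A hedged sketch of a concrete converter, assuming the abstract convert(params, fl_ctx) hook shown truncated above; the class name is illustrative and not taken from the repo:

```python
from typing import Any

import numpy as np

from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.params_converter import ParamsConverter


class ListToNumpyConverter(ParamsConverter):
    """Illustrative converter: turns plain Python lists back into numpy arrays."""

    def convert(self, params: Any, fl_ctx: FLContext) -> Any:
        return {k: np.asarray(v) for k, v in params.items()}


# Only Shareables carrying the "train" task are converted; any other task name
# falls through process() unchanged because of the supported_tasks check.
converter = ListToNumpyConverter(supported_tasks=["train"])
```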
2 changes: 2 additions & 0 deletions nvflare/app_common/ccwf/cse_client_ctl.py
@@ -169,6 +169,7 @@ def do_eval(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal)

model_to_validate = reply
model_to_validate.set_header(AppConstants.VALIDATE_TYPE, ValidateType.MODEL_VALIDATE)
+model_to_validate.set_header(FLContextKey.TASK_NAME, self.validation_task_name)
if model_type == ModelType.LOCAL:
model_to_validate.set_header(AppConstants.MODEL_OWNER, model_owner)

@@ -218,6 +219,7 @@ def _do_process_get_model_request(self, request: Shareable, fl_ctx: FLContext) -
if not self.local_model:
task_data = Shareable()
task_data.set_header(AppConstants.SUBMIT_MODEL_NAME, model_name)
+task_data.set_header(FLContextKey.TASK_NAME, self.submit_model_task_name)

abort_signal = Signal()
try:
4 changes: 3 additions & 1 deletion nvflare/app_common/ccwf/cyclic_client_ctl.py
@@ -13,7 +13,7 @@
# limitations under the License.
import random

-from nvflare.apis.fl_constant import ReturnCode
+from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
@@ -88,6 +88,8 @@ def do_learn_task(self, name: str, data: Shareable, fl_ctx: FLContext, abort_sig
global_weights = self.shareable_generator.shareable_to_learnable(data, fl_ctx)
fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, global_weights, private=True, sticky=True)

+data.set_header(FLContextKey.TASK_NAME, name)

# execute the task
result = self.execute_learn_task(data, fl_ctx, abort_signal)

2 changes: 2 additions & 0 deletions nvflare/app_common/ccwf/swarm_client_ctl.py
@@ -554,6 +554,8 @@ def do_learn_task(self, name: str, task_data: Shareable, fl_ctx: FLContext, abor

self.log_info(fl_ctx, f"Round {current_round} started.")

+task_data.set_header(FLContextKey.TASK_NAME, name)

# Some shareable generators assume the base model (GLOBAL_MODEL) is always available, which is true for
# server-controlled fed-avg. But this is not true for swarm learning.
# To make these generators happy, we create an empty global model here if not present.
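Both ccwf controllers (cyclic and swarm) now stamp the task name onto the outgoing Shareable via the FLContextKey.TASK_NAME header. A hedged sketch of how a downstream component could read that header back; the helper name is illustrative only:

```python
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.shareable import Shareable


def task_name_of(shareable: Shareable, default: str = "train") -> str:
    # Returns the task name stamped by the ccwf controller, or a fallback
    # when the sender never set the header.
    return shareable.get_header(FLContextKey.TASK_NAME, default)
```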
4 changes: 2 additions & 2 deletions nvflare/app_common/executors/launcher_executor.py
@@ -149,7 +149,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

if self._from_nvflare_converter is not None:
-shareable = self._from_nvflare_converter.process(shareable, fl_ctx)
+shareable = self._from_nvflare_converter.process(task_name, shareable, fl_ctx)

result = super().execute(task_name, shareable, fl_ctx, abort_signal)

@@ -161,7 +161,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

if self._to_nvflare_converter is not None:
-result = self._to_nvflare_converter.process(result, fl_ctx)
+result = self._to_nvflare_converter.process(task_name, result, fl_ctx)

self._finalize_external_execution(task_name, shareable, fl_ctx, abort_signal)

9 changes: 7 additions & 2 deletions nvflare/app_opt/pt/client_api_launcher_executor.py
@@ -13,6 +13,7 @@
# limitations under the License.

from nvflare.apis.fl_context import FLContext
+from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.executors.client_api_launcher_executor import ClientAPILauncherExecutor
from nvflare.app_opt.pt.decomposers import TensorDecomposer
from nvflare.app_opt.pt.params_converter import NumpyToPTParamsConverter, PTToNumpyParamsConverter
@@ -26,6 +27,10 @@ def initialize(self, fl_ctx: FLContext) -> None:
self._params_exchange_format = ExchangeFormat.PYTORCH
super().initialize(fl_ctx)
if self._from_nvflare_converter is None:
-self._from_nvflare_converter = NumpyToPTParamsConverter()
+self._from_nvflare_converter = NumpyToPTParamsConverter(
+    [AppConstants.TASK_TRAIN, AppConstants.TASK_VALIDATION]
+)
if self._to_nvflare_converter is None:
-self._to_nvflare_converter = PTToNumpyParamsConverter()
+self._to_nvflare_converter = PTToNumpyParamsConverter(
+    [AppConstants.TASK_TRAIN, AppConstants.TASK_SUBMIT_MODEL]
+)

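With these defaults, inbound NumPy-to-PyTorch conversion only runs for the train and validate tasks, and outbound PyTorch-to-NumPy conversion only for train and submit_model. If different behavior is needed, the converters can be constructed by hand with another task list; a sketch restricting both directions to training only, not a configuration taken from the repo:

```python
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_opt.pt.params_converter import NumpyToPTParamsConverter, PTToNumpyParamsConverter

# convert NumPy weights to torch tensors (and back) only for the "train" task
from_nvflare_converter = NumpyToPTParamsConverter([AppConstants.TASK_TRAIN])
to_nvflare_converter = PTToNumpyParamsConverter([AppConstants.TASK_TRAIN])
```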