Skip to content

Commit e218b6c

Browse files
Refactor Job API (#2799)
* refactor fed job api * improve docstrings * refactor fed job api * improve docstrings * polish changes * fix bugs; check args * fix getattr * refactor * fixes and cleanup * added ccwf and flower jobs * added optional model selector for fed-avg * update hello-world examples * formatting and updates * address feedback * remove unnecessary stuff * add license text * detect duplicate executor --------- Co-authored-by: Yan Cheng <[email protected]>
1 parent ed1633a commit e218b6c

38 files changed: +1510 −1056 lines changed

examples/advanced/job_api/pt/cyclic_cc_script_executor_cifar10.py

+12-20
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,26 @@
1414

1515
from src.net import Net
1616

17-
from nvflare import FedJob, ScriptExecutor
18-
from nvflare.app_common.ccwf import CyclicClientController, CyclicServerController
17+
from nvflare.app_common.ccwf.ccwf_job import CCWFJob, CyclicClientConfig, CyclicServerConfig
1918
from nvflare.app_common.ccwf.comps.simple_model_shareable_generator import SimpleModelShareableGenerator
19+
from nvflare.app_common.executors.script_executor import ScriptExecutor
2020
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
2121

2222
if __name__ == "__main__":
2323
n_clients = 2
2424
num_rounds = 3
2525
train_script = "src/cifar10_fl.py"
2626

27-
job = FedJob(name="cifar10_cyclic")
27+
job = CCWFJob(name="cifar10_cyclic")
2828

29-
controller = CyclicServerController(num_rounds=num_rounds, max_status_report_interval=300)
30-
job.to(controller, "server")
31-
32-
for i in range(n_clients):
33-
executor = ScriptExecutor(
34-
task_script_path=train_script, task_script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}"
35-
)
36-
job.to(executor, f"site-{i}", tasks=["train"], gpu=0)
37-
38-
# Add client-side controller for cyclic workflow
39-
executor = CyclicClientController()
40-
job.to(executor, f"site-{i}", tasks=["cyclic_*"])
41-
42-
# In swarm learning, each client uses a model persistor and shareable_generator
43-
job.to(PTFileModelPersistor(model=Net()), f"site-{i}", id="persistor")
44-
job.to(SimpleModelShareableGenerator(), f"site-{i}", id="shareable_generator")
29+
job.add_cyclic(
30+
server_config=CyclicServerConfig(num_rounds=num_rounds, max_status_report_interval=300),
31+
client_config=CyclicClientConfig(
32+
executor=ScriptExecutor(task_script_path=train_script),
33+
persistor=PTFileModelPersistor(model=Net()),
34+
shareable_generator=SimpleModelShareableGenerator(),
35+
),
36+
)
4537

4638
# job.export_job("/tmp/nvflare/jobs/job_config")
47-
job.simulator_run("/tmp/nvflare/jobs/workdir")
39+
job.simulator_run("/tmp/nvflare/jobs/workdir", n_clients=n_clients, gpu="0")

examples/advanced/job_api/pt/fedavg_model_learner_xsite_val_cifar10.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,18 @@
1515
import os
1616
import sys
1717

18-
sys.path.insert(0, os.path.join(os.getcwd(), "..", "..", "advanced", "cifar10"))
18+
sys.path.insert(0, os.path.join(os.getcwd(), "..", "..", "cifar10"))
1919

2020
from pt.learners.cifar10_model_learner import CIFAR10ModelLearner
2121
from pt.networks.cifar10_nets import ModerateCNN
2222
from pt.utils.cifar10_data_splitter import Cifar10DataSplitter
2323
from pt.utils.cifar10_data_utils import load_cifar10_data
2424

25-
from nvflare import FedAvg, FedJob
25+
from nvflare import FedJob
2626
from nvflare.app_common.executors.model_learner_executor import ModelLearnerExecutor
2727
from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval
28+
from nvflare.app_common.workflows.fedavg import FedAvg
29+
from nvflare.app_opt.pt.job_config.model import PTModel
2830

2931
if __name__ == "__main__":
3032
n_clients = 2
@@ -53,12 +55,12 @@
5355
job.to(data_splitter, "server")
5456

5557
# Define the initial global model and send to server
56-
job.to(ModerateCNN(), "server")
58+
job.to(PTModel(ModerateCNN()), "server")
5759

5860
for i in range(n_clients):
5961
learner = CIFAR10ModelLearner(train_idx_root=train_split_root, aggregation_epochs=aggregation_epochs, lr=0.01)
6062
executor = ModelLearnerExecutor(learner_id=job.as_id(learner))
61-
job.to(executor, f"site-{i+1}", gpu=0) # data splitter assumes client names start from 1
63+
job.to(executor, f"site-{i+1}") # data splitter assumes client names start from 1
6264

6365
# job.export_job("/tmp/nvflare/jobs/job_config")
64-
job.simulator_run("/tmp/nvflare/jobs/workdir")
66+
job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")

examples/advanced/job_api/pt/fedavg_script_executor_cifar10.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414

1515
from src.net import Net
1616

17-
from nvflare import FedAvg, FedJob, ScriptExecutor
17+
from nvflare.app_common.executors.script_executor import ScriptExecutor
18+
from nvflare.app_common.workflows.fedavg import FedAvg
19+
from nvflare.app_opt.pt.job_config.model import PTModel
20+
21+
# from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob
22+
from nvflare.job_config.api import FedJob
1823

1924
if __name__ == "__main__":
2025
n_clients = 2
@@ -32,16 +37,18 @@
3237
# job.to_server(controller)
3338

3439
# Define the initial global model and send to server
35-
job.to(Net(), "server")
36-
# job.to_server(Net())
40+
job.to(PTModel(Net()), "server")
41+
42+
# Note: We can optionally replace the above code with the FedAvgJob, which is a pattern to simplify FedAvg job creations
43+
# job = FedAvgJob(name="cifar10_fedavg", num_rounds=num_rounds, n_clients=n_clients, initial_model=Net())
3744

3845
# Add clients
3946
for i in range(n_clients):
4047
executor = ScriptExecutor(
4148
task_script_path=train_script, task_script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}"
4249
)
43-
job.to(executor, f"site-{i}", gpu=0)
44-
# job.to_clients(executor)
50+
job.to(executor, target=f"site-{i}")
51+
# job.to_clients(executor)
4552

4653
# job.export_job("/tmp/nvflare/jobs/job_config")
47-
job.simulator_run("/tmp/nvflare/jobs/workdir")
54+
job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")

examples/advanced/job_api/pt/fedavg_script_executor_dp_filter_cifar10.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,25 @@
1414

1515
from src.net import Net
1616

17-
from nvflare import FedAvg, FedJob, FilterType, ScriptExecutor
17+
from nvflare import FilterType
18+
from nvflare.app_common.executors.script_executor import ScriptExecutor
1819
from nvflare.app_common.filters.percentile_privacy import PercentilePrivacy
20+
from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob
1921

2022
if __name__ == "__main__":
2123
n_clients = 2
2224
num_rounds = 2
2325
train_script = "src/cifar10_fl.py"
2426

25-
job = FedJob(name="cifar10_fedavg_privacy")
26-
27-
# Define the controller workflow and send to server
28-
controller = FedAvg(
29-
num_clients=n_clients,
30-
num_rounds=num_rounds,
31-
)
32-
job.to(controller, "server")
33-
34-
# Define the initial global model and send to server
35-
job.to(Net(), "server")
27+
job = FedAvgJob(name="cifar10_fedavg_privacy", num_rounds=num_rounds, n_clients=n_clients, initial_model=Net())
3628

3729
for i in range(n_clients):
3830
executor = ScriptExecutor(task_script_path=train_script, task_script_args="")
39-
job.to(executor, f"site-{i}", tasks=["train"], gpu=0)
31+
job.to(executor, f"site-{i}", tasks=["train"])
4032

4133
# add privacy filter.
4234
pp_filter = PercentilePrivacy(percentile=10, gamma=0.01)
4335
job.to(pp_filter, f"site-{i}", tasks=["train"], filter_type=FilterType.TASK_RESULT)
4436

4537
# job.export_job("/tmp/nvflare/jobs/job_config")
46-
job.simulator_run("/tmp/nvflare/jobs/workdir")
38+
job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")

examples/advanced/job_api/pt/fedavg_script_executor_lightning_cifar10.py

+5-14
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,22 @@
1414

1515
from src.lit_net import LitNet
1616

17-
from nvflare import FedAvg, FedJob, ScriptExecutor
17+
from nvflare.app_common.executors.script_executor import ScriptExecutor
18+
from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob
1819

1920
if __name__ == "__main__":
2021
n_clients = 2
2122
num_rounds = 2
2223
train_script = "src/cifar10_lightning_fl.py"
2324

24-
job = FedJob(name="cifar10_fedavg_lightning")
25-
26-
# Define the controller workflow and send to server
27-
controller = FedAvg(
28-
num_clients=n_clients,
29-
num_rounds=num_rounds,
30-
)
31-
job.to(controller, "server")
32-
33-
# Define the initial global model and send to server
34-
job.to(LitNet(), "server")
25+
job = FedAvgJob(name="cifar10_fedavg_lightning", num_rounds=num_rounds, n_clients=n_clients, initial_model=LitNet())
3526

3627
# Add clients
3728
for i in range(n_clients):
3829
executor = ScriptExecutor(
3930
task_script_path=train_script, task_script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}"
4031
)
41-
job.to(executor, f"site-{i}", gpu=0)
32+
job.to(executor, f"site-{i}")
4233

4334
# job.export_job("/tmp/nvflare/jobs/job_config")
44-
job.simulator_run("/tmp/nvflare/jobs/workdir")
35+
job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")

examples/advanced/job_api/pt/swarm_script_executor_cifar10.py

+14-40
Original file line numberDiff line numberDiff line change
@@ -14,56 +14,30 @@
1414

1515
from src.net import Net
1616

17-
from nvflare import FedJob, ScriptExecutor
1817
from nvflare.apis.dxo import DataKind
1918
from nvflare.app_common.aggregators.intime_accumulate_model_aggregator import InTimeAccumulateWeightedAggregator
20-
from nvflare.app_common.ccwf import (
21-
CrossSiteEvalClientController,
22-
CrossSiteEvalServerController,
23-
SwarmClientController,
24-
SwarmServerController,
25-
)
19+
from nvflare.app_common.ccwf.ccwf_job import CCWFJob, CrossSiteEvalConfig, SwarmClientConfig, SwarmServerConfig
2620
from nvflare.app_common.ccwf.comps.simple_model_shareable_generator import SimpleModelShareableGenerator
21+
from nvflare.app_common.executors.script_executor import ScriptExecutor
2722
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
2823

2924
if __name__ == "__main__":
3025
n_clients = 2
3126
num_rounds = 3
3227
train_script = "src/train_eval_submit.py"
3328

34-
job = FedJob(name="cifar10_swarm")
35-
36-
controller = SwarmServerController(
37-
num_rounds=num_rounds,
29+
job = CCWFJob(name="cifar10_swarm")
30+
aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS)
31+
job.add_swarm(
32+
server_config=SwarmServerConfig(num_rounds=num_rounds),
33+
client_config=SwarmClientConfig(
34+
executor=ScriptExecutor(task_script_path=train_script, evaluate_task_name="validate"),
35+
aggregator=aggregator,
36+
persistor=PTFileModelPersistor(model=Net()),
37+
shareable_generator=SimpleModelShareableGenerator(),
38+
),
39+
cse_config=CrossSiteEvalConfig(eval_task_timeout=300),
3840
)
39-
job.to(controller, "server")
40-
controller = CrossSiteEvalServerController(eval_task_timeout=300)
41-
job.to(controller, "server")
42-
43-
# Define the initial server model
44-
job.to(Net(), "server")
45-
46-
for i in range(n_clients):
47-
executor = ScriptExecutor(task_script_path=train_script, evaluate_task_name="validate")
48-
job.to(executor, f"site-{i}", gpu=0, tasks=["train", "validate", "submit_model"])
49-
50-
# In swarm learning, each client acts also as an aggregator
51-
aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS)
52-
53-
# In swarm learning, each client uses a model persistor and shareable_generator
54-
persistor = PTFileModelPersistor(model=Net())
55-
shareable_generator = SimpleModelShareableGenerator()
56-
57-
persistor_id = job.as_id(persistor)
58-
client_controller = SwarmClientController(
59-
aggregator_id=job.as_id(aggregator),
60-
persistor_id=persistor_id,
61-
shareable_generator_id=job.as_id(shareable_generator),
62-
)
63-
job.to(client_controller, f"site-{i}", tasks=["swarm_*"])
64-
65-
client_controller = CrossSiteEvalClientController(persistor_id=persistor_id)
66-
job.to(client_controller, f"site-{i}", tasks=["cse_*"])
6741

6842
# job.export_job("/tmp/nvflare/jobs/job_config")
69-
job.simulator_run("/tmp/nvflare/jobs/workdir")
43+
job.simulator_run("/tmp/nvflare/jobs/workdir", n_clients=n_clients, gpu="0")

examples/advanced/job_api/sklearn/kmeans_script_executor_higgs.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
from src.kmeans_assembler import KMeansAssembler
1919
from src.split_csv import distribute_header_file, split_csv
2020

21-
from nvflare import FedJob, ScriptExecutor
21+
from nvflare import FedJob
2222
from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator
23+
from nvflare.app_common.executors.script_executor import ScriptExecutor
2324
from nvflare.app_common.shareablegenerators.full_model_shareable_generator import FullModelShareableGenerator
2425
from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
2526
from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor

0 commit comments

Comments (0)